From 3870a96592c6a93d3a68391f6cbaecd9c7a2bc97 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 16 Oct 2024 03:15:48 +0200 Subject: Re-enable all failing PTX tests (#277) Additionally remove unused compilation paths --- .devcontainer/Dockerfile | 37 + .devcontainer/devcontainer.json | 34 + Cargo.toml | 4 - comgr/src/lib.rs | 94 +- llvm_zluda/src/lib.cpp | 182 + llvm_zluda/src/lib.rs | 71 + ptx/Cargo.toml | 12 - ptx/build.rs | 5 - ptx/lib/zluda_ptx_impl.bc | Bin 34052 -> 4624 bytes ptx/lib/zluda_ptx_impl.cl | 344 - ptx/lib/zluda_ptx_impl.cpp | 151 + ptx/lib/zluda_ptx_impl.spv | Bin 106076 -> 0 bytes ptx/src/ast.rs | 1074 --- ptx/src/lib.rs | 182 +- .../pass/convert_dynamic_shared_memory_usage.rs | 299 - ptx/src/pass/convert_to_stateful_memory_access.rs | 524 -- ptx/src/pass/convert_to_typed.rs | 138 - ptx/src/pass/deparamize_functions.rs | 185 +- ptx/src/pass/emit_llvm.rs | 2159 +++++- ptx/src/pass/emit_spirv.rs | 2762 ------- ptx/src/pass/expand_arguments.rs | 181 - ptx/src/pass/expand_operands.rs | 50 +- ptx/src/pass/extract_globals.rs | 281 - ptx/src/pass/fix_special_registers.rs | 130 - ptx/src/pass/fix_special_registers2.rs | 37 +- ptx/src/pass/hoist_globals.rs | 2 +- ptx/src/pass/insert_explicit_load_store.rs | 90 +- ptx/src/pass/insert_implicit_conversions.rs | 438 -- ptx/src/pass/insert_mem_ssa_statements.rs | 275 - ptx/src/pass/mod.rs | 1493 +--- ptx/src/pass/normalize_identifiers.rs | 80 - ptx/src/pass/normalize_identifiers2.rs | 3 +- ptx/src/pass/normalize_labels.rs | 49 - ptx/src/pass/normalize_predicates.rs | 44 - .../replace_instructions_with_function_calls.rs | 187 + ptx/src/ptx.lalrpop | 2198 ------ ptx/src/test/mod.rs | 17 +- ptx/src/test/spirv_run/activemask.spvtxt | 45 - ptx/src/test/spirv_run/add.spvtxt | 47 - ptx/src/test/spirv_run/add_non_coherent.spvtxt | 47 - ptx/src/test/spirv_run/add_tuning.spvtxt | 55 - ptx/src/test/spirv_run/and.spvtxt | 62 - ptx/src/test/spirv_run/assertfail.spvtxt | 105 - ptx/src/test/spirv_run/atom_add.spvtxt | 85 - ptx/src/test/spirv_run/atom_add_float.spvtxt | 90 - ptx/src/test/spirv_run/atom_cas.spvtxt | 77 - ptx/src/test/spirv_run/atom_inc.spvtxt | 87 - ptx/src/test/spirv_run/b64tof64.spvtxt | 50 - ptx/src/test/spirv_run/bfe.spvtxt | 76 - ptx/src/test/spirv_run/bfi.spvtxt | 90 - ptx/src/test/spirv_run/block.spvtxt | 52 - ptx/src/test/spirv_run/bra.spvtxt | 57 - ptx/src/test/spirv_run/brev.spvtxt | 52 - ptx/src/test/spirv_run/call.spvtxt | 71 - ptx/src/test/spirv_run/clz.spvtxt | 52 - ptx/src/test/spirv_run/const.spvtxt | 112 - ptx/src/test/spirv_run/constant_f32.spvtxt | 48 - ptx/src/test/spirv_run/constant_negative.spvtxt | 48 - ptx/src/test/spirv_run/cos.spvtxt | 48 - ptx/src/test/spirv_run/cvt_f64_f32.spvtxt | 55 - ptx/src/test/spirv_run/cvt_rni.spvtxt | 69 - ptx/src/test/spirv_run/cvt_rzi.spvtxt | 69 - ptx/src/test/spirv_run/cvt_s16_s8.spvtxt | 59 - ptx/src/test/spirv_run/cvt_s32_f32.spvtxt | 82 - ptx/src/test/spirv_run/cvt_s64_s32.spvtxt | 55 - ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt | 56 - ptx/src/test/spirv_run/cvta.spvtxt | 65 - ptx/src/test/spirv_run/div_approx.spvtxt | 60 - ptx/src/test/spirv_run/ex2.spvtxt | 48 - ptx/src/test/spirv_run/extern_func.spvtxt | 75 - ptx/src/test/spirv_run/extern_shared.spvtxt | 56 - ptx/src/test/spirv_run/extern_shared_call.spvtxt | 75 - ptx/src/test/spirv_run/fma.spvtxt | 69 - ptx/src/test/spirv_run/func_ptr.spvtxt | 77 - ptx/src/test/spirv_run/global_array.spvtxt | 53 - ptx/src/test/spirv_run/implicit_param.spvtxt | 53 - ptx/src/test/spirv_run/lanemask_lt.spvtxt | 70 - ptx/src/test/spirv_run/ld_st.spvtxt | 42 - ptx/src/test/spirv_run/ld_st_implicit.spvtxt | 56 - ptx/src/test/spirv_run/ld_st_offset.spvtxt | 63 - ptx/src/test/spirv_run/lg2.spvtxt | 48 - ptx/src/test/spirv_run/local_align.spvtxt | 49 - ptx/src/test/spirv_run/mad_s32.spvtxt | 87 - ptx/src/test/spirv_run/max.spvtxt | 59 - ptx/src/test/spirv_run/membar.spvtxt | 49 - ptx/src/test/spirv_run/min.spvtxt | 59 - ptx/src/test/spirv_run/mod.rs | 251 +- ptx/src/test/spirv_run/mov.spvtxt | 46 - ptx/src/test/spirv_run/mov_address.spvtxt | 33 - ptx/src/test/spirv_run/mul_ftz.spvtxt | 59 - ptx/src/test/spirv_run/mul_hi.spvtxt | 47 - ptx/src/test/spirv_run/mul_lo.spvtxt | 47 - ptx/src/test/spirv_run/mul_non_ftz.spvtxt | 59 - ptx/src/test/spirv_run/mul_wide.spvtxt | 66 - ptx/src/test/spirv_run/neg.spvtxt | 47 - .../test/spirv_run/non_scalar_ptr_offset.spvtxt | 60 - ptx/src/test/spirv_run/not.spvtxt | 48 - ptx/src/test/spirv_run/ntid.spvtxt | 60 - ptx/src/test/spirv_run/or.spvtxt | 60 - ptx/src/test/spirv_run/popc.spvtxt | 52 - ptx/src/test/spirv_run/pred_not.spvtxt | 82 - ptx/src/test/spirv_run/prmt.spvtxt | 67 - ptx/src/test/spirv_run/rcp.spvtxt | 48 - ptx/src/test/spirv_run/reg_local.spvtxt | 76 - ptx/src/test/spirv_run/rem.spvtxt | 59 - ptx/src/test/spirv_run/rsqrt.spvtxt | 48 - ptx/src/test/spirv_run/selp.spvtxt | 61 - ptx/src/test/spirv_run/selp_true.spvtxt | 61 - ptx/src/test/spirv_run/setp.spvtxt | 77 - ptx/src/test/spirv_run/setp_gt.spvtxt | 79 - ptx/src/test/spirv_run/setp_leu.spvtxt | 79 - ptx/src/test/spirv_run/setp_nan.spvtxt | 228 - ptx/src/test/spirv_run/setp_num.spvtxt | 240 - ptx/src/test/spirv_run/shared_ptr_32.spvtxt | 73 - .../test/spirv_run/shared_ptr_take_address.spvtxt | 64 - ptx/src/test/spirv_run/shared_unify_extern.spvtxt | 118 - ptx/src/test/spirv_run/shared_unify_local.spvtxt | 117 - ptx/src/test/spirv_run/shared_variable.spvtxt | 61 - ptx/src/test/spirv_run/shl.spvtxt | 51 - ptx/src/test/spirv_run/shl_link_hack.ptx | 30 - ptx/src/test/spirv_run/shl_link_hack.spvtxt | 65 - ptx/src/test/spirv_run/shr.spvtxt | 48 - ptx/src/test/spirv_run/sign_extend.spvtxt | 47 - ptx/src/test/spirv_run/sin.spvtxt | 48 - ptx/src/test/spirv_run/sqrt.spvtxt | 48 - ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt | 93 - .../spirv_run/stateful_ld_st_ntid_chain.spvtxt | 97 - .../test/spirv_run/stateful_ld_st_ntid_sub.spvtxt | 107 - .../test/spirv_run/stateful_ld_st_simple.spvtxt | 65 - ptx/src/test/spirv_run/stateful_neg_offset.spvtxt | 80 - ptx/src/test/spirv_run/sub.spvtxt | 47 - ptx/src/test/spirv_run/vector.spvtxt | 99 - ptx/src/test/spirv_run/vector4.spvtxt | 56 - ptx/src/test/spirv_run/vector_extract.spvtxt | 125 - ptx/src/test/spirv_run/xor.spvtxt | 59 - ptx/src/translate.rs | 8181 -------------------- ptx_parser/src/ast.rs | 53 +- ptx_parser/src/lib.rs | 24 +- 138 files changed, 3159 insertions(+), 25763 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json delete mode 100644 ptx/build.rs delete mode 100644 ptx/lib/zluda_ptx_impl.cl create mode 100644 ptx/lib/zluda_ptx_impl.cpp delete mode 100644 ptx/lib/zluda_ptx_impl.spv delete mode 100644 ptx/src/ast.rs delete mode 100644 ptx/src/pass/convert_dynamic_shared_memory_usage.rs delete mode 100644 ptx/src/pass/convert_to_stateful_memory_access.rs delete mode 100644 ptx/src/pass/convert_to_typed.rs delete mode 100644 ptx/src/pass/emit_spirv.rs delete mode 100644 ptx/src/pass/expand_arguments.rs delete mode 100644 ptx/src/pass/extract_globals.rs delete mode 100644 ptx/src/pass/fix_special_registers.rs delete mode 100644 ptx/src/pass/insert_implicit_conversions.rs delete mode 100644 ptx/src/pass/insert_mem_ssa_statements.rs delete mode 100644 ptx/src/pass/normalize_identifiers.rs delete mode 100644 ptx/src/pass/normalize_labels.rs delete mode 100644 ptx/src/pass/normalize_predicates.rs create mode 100644 ptx/src/pass/replace_instructions_with_function_calls.rs delete mode 100644 ptx/src/ptx.lalrpop delete mode 100644 ptx/src/test/spirv_run/activemask.spvtxt delete mode 100644 ptx/src/test/spirv_run/add.spvtxt delete mode 100644 ptx/src/test/spirv_run/add_non_coherent.spvtxt delete mode 100644 ptx/src/test/spirv_run/add_tuning.spvtxt delete mode 100644 ptx/src/test/spirv_run/and.spvtxt delete mode 100644 ptx/src/test/spirv_run/assertfail.spvtxt delete mode 100644 ptx/src/test/spirv_run/atom_add.spvtxt delete mode 100644 ptx/src/test/spirv_run/atom_add_float.spvtxt delete mode 100644 ptx/src/test/spirv_run/atom_cas.spvtxt delete mode 100644 ptx/src/test/spirv_run/atom_inc.spvtxt delete mode 100644 ptx/src/test/spirv_run/b64tof64.spvtxt delete mode 100644 ptx/src/test/spirv_run/bfe.spvtxt delete mode 100644 ptx/src/test/spirv_run/bfi.spvtxt delete mode 100644 ptx/src/test/spirv_run/block.spvtxt delete mode 100644 ptx/src/test/spirv_run/bra.spvtxt delete mode 100644 ptx/src/test/spirv_run/brev.spvtxt delete mode 100644 ptx/src/test/spirv_run/call.spvtxt delete mode 100644 ptx/src/test/spirv_run/clz.spvtxt delete mode 100644 ptx/src/test/spirv_run/const.spvtxt delete mode 100644 ptx/src/test/spirv_run/constant_f32.spvtxt delete mode 100644 ptx/src/test/spirv_run/constant_negative.spvtxt delete mode 100644 ptx/src/test/spirv_run/cos.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_f64_f32.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_rni.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_rzi.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_s16_s8.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_s32_f32.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_s64_s32.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt delete mode 100644 ptx/src/test/spirv_run/cvta.spvtxt delete mode 100644 ptx/src/test/spirv_run/div_approx.spvtxt delete mode 100644 ptx/src/test/spirv_run/ex2.spvtxt delete mode 100644 ptx/src/test/spirv_run/extern_func.spvtxt delete mode 100644 ptx/src/test/spirv_run/extern_shared.spvtxt delete mode 100644 ptx/src/test/spirv_run/extern_shared_call.spvtxt delete mode 100644 ptx/src/test/spirv_run/fma.spvtxt delete mode 100644 ptx/src/test/spirv_run/func_ptr.spvtxt delete mode 100644 ptx/src/test/spirv_run/global_array.spvtxt delete mode 100644 ptx/src/test/spirv_run/implicit_param.spvtxt delete mode 100644 ptx/src/test/spirv_run/lanemask_lt.spvtxt delete mode 100644 ptx/src/test/spirv_run/ld_st.spvtxt delete mode 100644 ptx/src/test/spirv_run/ld_st_implicit.spvtxt delete mode 100644 ptx/src/test/spirv_run/ld_st_offset.spvtxt delete mode 100644 ptx/src/test/spirv_run/lg2.spvtxt delete mode 100644 ptx/src/test/spirv_run/local_align.spvtxt delete mode 100644 ptx/src/test/spirv_run/mad_s32.spvtxt delete mode 100644 ptx/src/test/spirv_run/max.spvtxt delete mode 100644 ptx/src/test/spirv_run/membar.spvtxt delete mode 100644 ptx/src/test/spirv_run/min.spvtxt delete mode 100644 ptx/src/test/spirv_run/mov.spvtxt delete mode 100644 ptx/src/test/spirv_run/mov_address.spvtxt delete mode 100644 ptx/src/test/spirv_run/mul_ftz.spvtxt delete mode 100644 ptx/src/test/spirv_run/mul_hi.spvtxt delete mode 100644 ptx/src/test/spirv_run/mul_lo.spvtxt delete mode 100644 ptx/src/test/spirv_run/mul_non_ftz.spvtxt delete mode 100644 ptx/src/test/spirv_run/mul_wide.spvtxt delete mode 100644 ptx/src/test/spirv_run/neg.spvtxt delete mode 100644 ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt delete mode 100644 ptx/src/test/spirv_run/not.spvtxt delete mode 100644 ptx/src/test/spirv_run/ntid.spvtxt delete mode 100644 ptx/src/test/spirv_run/or.spvtxt delete mode 100644 ptx/src/test/spirv_run/popc.spvtxt delete mode 100644 ptx/src/test/spirv_run/pred_not.spvtxt delete mode 100644 ptx/src/test/spirv_run/prmt.spvtxt delete mode 100644 ptx/src/test/spirv_run/rcp.spvtxt delete mode 100644 ptx/src/test/spirv_run/reg_local.spvtxt delete mode 100644 ptx/src/test/spirv_run/rem.spvtxt delete mode 100644 ptx/src/test/spirv_run/rsqrt.spvtxt delete mode 100644 ptx/src/test/spirv_run/selp.spvtxt delete mode 100644 ptx/src/test/spirv_run/selp_true.spvtxt delete mode 100644 ptx/src/test/spirv_run/setp.spvtxt delete mode 100644 ptx/src/test/spirv_run/setp_gt.spvtxt delete mode 100644 ptx/src/test/spirv_run/setp_leu.spvtxt delete mode 100644 ptx/src/test/spirv_run/setp_nan.spvtxt delete mode 100644 ptx/src/test/spirv_run/setp_num.spvtxt delete mode 100644 ptx/src/test/spirv_run/shared_ptr_32.spvtxt delete mode 100644 ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt delete mode 100644 ptx/src/test/spirv_run/shared_unify_extern.spvtxt delete mode 100644 ptx/src/test/spirv_run/shared_unify_local.spvtxt delete mode 100644 ptx/src/test/spirv_run/shared_variable.spvtxt delete mode 100644 ptx/src/test/spirv_run/shl.spvtxt delete mode 100644 ptx/src/test/spirv_run/shl_link_hack.ptx delete mode 100644 ptx/src/test/spirv_run/shl_link_hack.spvtxt delete mode 100644 ptx/src/test/spirv_run/shr.spvtxt delete mode 100644 ptx/src/test/spirv_run/sign_extend.spvtxt delete mode 100644 ptx/src/test/spirv_run/sin.spvtxt delete mode 100644 ptx/src/test/spirv_run/sqrt.spvtxt delete mode 100644 ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt delete mode 100644 ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt delete mode 100644 ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt delete mode 100644 ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt delete mode 100644 ptx/src/test/spirv_run/stateful_neg_offset.spvtxt delete mode 100644 ptx/src/test/spirv_run/sub.spvtxt delete mode 100644 ptx/src/test/spirv_run/vector.spvtxt delete mode 100644 ptx/src/test/spirv_run/vector4.spvtxt delete mode 100644 ptx/src/test/spirv_run/vector_extract.spvtxt delete mode 100644 ptx/src/test/spirv_run/xor.spvtxt delete mode 100644 ptx/src/translate.rs diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..3df6b99 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,37 @@ +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 + +RUN DEBIAN_FRONTEND=noninteractive apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + wget \ + build-essential \ + cmake \ + ninja-build \ + python3 \ + ripgrep \ + git \ + ltrace + +# Feel free to change to a newer version if you have a newer verison on your host +ARG CUDA_VERSION=12-4 +# Docker <-> host driver version compatiblity is newer host <-> older docker +# We don't care about a specific driver version, so pick oldest 5XX +ARG CUDA_DRIVER=515 +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + nvidia-utils-${CUDA_DRIVER} \ + cuda-cudart-${CUDA_VERSION} + +ARG ROCM_VERSION=6.2.2 +RUN mkdir --parents --mode=0755 /etc/apt/keyrings && \ + wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \ + gpg --dearmor | tee /etc/apt/keyrings/rocm.gpg > /dev/null && \ + echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${ROCM_VERSION} jammy main" > /etc/apt/sources.list.d/rocm.list && \ + echo 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' > /etc/apt/preferences.d/rocm-pin-600 && \ + DEBIAN_FRONTEND=noninteractive apt update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + rocminfo \ + rocm-gdb \ + rocm-smi-lib \ + hip-runtime-amd && \ + echo '/opt/rocm/lib' > /etc/ld.so.conf.d/rocm.conf && \ + ldconfig + +ENV PATH=$PATH:/opt/rocm-6.2.2/bin + diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..7cae35a --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,34 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/rust +{ + "name": "zluda", + "build": { + "dockerfile": "Dockerfile" + }, + "securityOpt": [ "seccomp=unconfined" ], + "runArgs": [ + "--runtime=nvidia", + "--device=/dev/kfd", + "--device=/dev/dri", + "--group-add=video" + ], + "mounts": [ + { + "source": "${localEnv:HOME}/.cargo/", + "target": "/root/.cargo", + "type": "bind" + } + ], + // https://containers.dev/features. + "features": { + "ghcr.io/devcontainers/features/rust:1": {} + }, + // https://aka.ms/dev-containers-non-root. + "remoteUser": "root", + //"hostRequirements": { "gpu": "optional" } + "customizations": { + "vscode": { + "extensions": [ "mhutchie.git-graph" ], + } +} +} diff --git a/Cargo.toml b/Cargo.toml index 93585a1..8a7467a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,3 @@ members = [ ] default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"] - -[patch.crates-io] -rspirv = { git = 'https://github.com/vosen/rspirv', rev = '9826e59a232c4a426482cda12f88d11bfda3ff9c' } -spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '9826e59a232c4a426482cda12f88d11bfda3ff9c' } diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index f27a127..129dc14 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -1,5 +1,5 @@ use amd_comgr_sys::*; -use std::{ffi::CStr, mem, ptr}; +use std::{ffi::CStr, iter, mem, ptr}; struct Data(amd_comgr_data_t); @@ -79,6 +79,24 @@ impl ActionInfo { unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) } } + fn set_language(&self, language: amd_comgr_language_t) -> Result<(), amd_comgr_status_s> { + unsafe { amd_comgr_action_info_set_language(self.get(), language) } + } + + fn set_options<'a>( + &self, + options: impl Iterator, + ) -> Result<(), amd_comgr_status_s> { + let options = options.map(|x| x.as_ptr()).collect::>(); + unsafe { + amd_comgr_action_info_set_option_list( + self.get(), + options.as_ptr().cast_mut(), + options.len(), + ) + } + } + fn get(&self) -> amd_comgr_action_info_t { self.0 } @@ -90,36 +108,62 @@ impl Drop for ActionInfo { } } -pub fn compile_bitcode(gcn_arch: &CStr, buffer: &[u8]) -> Result, amd_comgr_status_s> { +pub fn compile_bitcode( + gcn_arch: &CStr, + main_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, amd_comgr_status_s> { use amd_comgr_sys::*; let bitcode_data_set = DataSet::new()?; - let bitcode_data = Data::new( + let main_bitcode_data = Data::new( amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, c"zluda.bc", - buffer, + main_buffer, + )?; + bitcode_data_set.add(&main_bitcode_data)?; + let stdlib_bitcode_data = Data::new( + amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, + c"ptx_impl.bc", + ptx_impl, + )?; + bitcode_data_set.add(&stdlib_bitcode_data)?; + let lang_action_info = ActionInfo::new()?; + lang_action_info.set_isa_name(gcn_arch)?; + lang_action_info.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?; + let with_device_libs = do_action( + &bitcode_data_set, + &lang_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, + )?; + let linked_data_set = do_action( + &with_device_libs, + &lang_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_BC_TO_BC, + )?; + let compile_action_info = ActionInfo::new()?; + compile_action_info.set_isa_name(gcn_arch)?; + compile_action_info.set_options(iter::once(c"-O3"))?; + let reloc_data_set = do_action( + &linked_data_set, + &compile_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, + )?; + let exec_data_set = do_action( + &reloc_data_set, + &compile_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, )?; - bitcode_data_set.add(&bitcode_data)?; - let reloc_data_set = DataSet::new()?; - let action_info = ActionInfo::new()?; - action_info.set_isa_name(gcn_arch)?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, - action_info.get(), - bitcode_data_set.get(), - reloc_data_set.get(), - ) - }?; - let exec_data_set = DataSet::new()?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, - action_info.get(), - reloc_data_set.get(), - exec_data_set.get(), - ) - }?; let executable = exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?; executable.copy_content() } + +fn do_action( + data_set: &DataSet, + action: &ActionInfo, + kind: amd_comgr_action_kind_t, +) -> Result { + let result = DataSet::new()?; + unsafe { amd_comgr_do_action(kind, action.get(), data_set.get(), result.get()) }?; + Ok(result) +} diff --git a/llvm_zluda/src/lib.cpp b/llvm_zluda/src/lib.cpp index 3da88fb..886aa0d 100644 --- a/llvm_zluda/src/lib.cpp +++ b/llvm_zluda/src/lib.cpp @@ -1,6 +1,144 @@ #include #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +typedef enum +{ + LLVMZludaAtomicRMWBinOpXchg, /**< Set the new value and return the one old */ + LLVMZludaAtomicRMWBinOpAdd, /**< Add a value and return the old one */ + LLVMZludaAtomicRMWBinOpSub, /**< Subtract a value and return the old one */ + LLVMZludaAtomicRMWBinOpAnd, /**< And a value and return the old one */ + LLVMZludaAtomicRMWBinOpNand, /**< Not-And a value and return the old one */ + LLVMZludaAtomicRMWBinOpOr, /**< OR a value and return the old one */ + LLVMZludaAtomicRMWBinOpXor, /**< Xor a value and return the old one */ + LLVMZludaAtomicRMWBinOpMax, /**< Sets the value if it's greater than the + original using a signed comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpMin, /**< Sets the value if it's Smaller than the + original using a signed comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpUMax, /**< Sets the value if it's greater than the + original using an unsigned comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpUMin, /**< Sets the value if it's greater than the + original using an unsigned comparison and return + the old one */ + LLVMZludaAtomicRMWBinOpFAdd, /**< Add a floating point value and return the + old one */ + LLVMZludaAtomicRMWBinOpFSub, /**< Subtract a floating point value and return the + old one */ + LLVMZludaAtomicRMWBinOpFMax, /**< Sets the value if it's greater than the + original using an floating point comparison and + return the old one */ + LLVMZludaAtomicRMWBinOpFMin, /**< Sets the value if it's smaller than the + original using an floating point comparison and + return the old one */ + LLVMZludaAtomicRMWBinOpUIncWrap, /**< Increments the value, wrapping back to zero + when incremented above input value */ + LLVMZludaAtomicRMWBinOpUDecWrap, /**< Decrements the value, wrapping back to + the input value when decremented below zero */ +} LLVMZludaAtomicRMWBinOp; + +static llvm::AtomicRMWInst::BinOp mapFromLLVMRMWBinOp(LLVMZludaAtomicRMWBinOp BinOp) +{ + switch (BinOp) + { + case LLVMZludaAtomicRMWBinOpXchg: + return llvm::AtomicRMWInst::Xchg; + case LLVMZludaAtomicRMWBinOpAdd: + return llvm::AtomicRMWInst::Add; + case LLVMZludaAtomicRMWBinOpSub: + return llvm::AtomicRMWInst::Sub; + case LLVMZludaAtomicRMWBinOpAnd: + return llvm::AtomicRMWInst::And; + case LLVMZludaAtomicRMWBinOpNand: + return llvm::AtomicRMWInst::Nand; + case LLVMZludaAtomicRMWBinOpOr: + return llvm::AtomicRMWInst::Or; + case LLVMZludaAtomicRMWBinOpXor: + return llvm::AtomicRMWInst::Xor; + case LLVMZludaAtomicRMWBinOpMax: + return llvm::AtomicRMWInst::Max; + case LLVMZludaAtomicRMWBinOpMin: + return llvm::AtomicRMWInst::Min; + case LLVMZludaAtomicRMWBinOpUMax: + return llvm::AtomicRMWInst::UMax; + case LLVMZludaAtomicRMWBinOpUMin: + return llvm::AtomicRMWInst::UMin; + case LLVMZludaAtomicRMWBinOpFAdd: + return llvm::AtomicRMWInst::FAdd; + case LLVMZludaAtomicRMWBinOpFSub: + return llvm::AtomicRMWInst::FSub; + case LLVMZludaAtomicRMWBinOpFMax: + return llvm::AtomicRMWInst::FMax; + case LLVMZludaAtomicRMWBinOpFMin: + return llvm::AtomicRMWInst::FMin; + case LLVMZludaAtomicRMWBinOpUIncWrap: + return llvm::AtomicRMWInst::UIncWrap; + case LLVMZludaAtomicRMWBinOpUDecWrap: + return llvm::AtomicRMWInst::UDecWrap; + } + + llvm_unreachable("Invalid LLVMZludaAtomicRMWBinOp value!"); +} + +static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) +{ + switch (Ordering) + { + case LLVMAtomicOrderingNotAtomic: + return AtomicOrdering::NotAtomic; + case LLVMAtomicOrderingUnordered: + return AtomicOrdering::Unordered; + case LLVMAtomicOrderingMonotonic: + return AtomicOrdering::Monotonic; + case LLVMAtomicOrderingAcquire: + return AtomicOrdering::Acquire; + case LLVMAtomicOrderingRelease: + return AtomicOrdering::Release; + case LLVMAtomicOrderingAcquireRelease: + return AtomicOrdering::AcquireRelease; + case LLVMAtomicOrderingSequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; + } + + llvm_unreachable("Invalid LLVMAtomicOrdering value!"); +} + +typedef unsigned LLVMFastMathFlags; + +enum +{ + LLVMFastMathAllowReassoc = (1 << 0), + LLVMFastMathNoNaNs = (1 << 1), + LLVMFastMathNoInfs = (1 << 2), + LLVMFastMathNoSignedZeros = (1 << 3), + LLVMFastMathAllowReciprocal = (1 << 4), + LLVMFastMathAllowContract = (1 << 5), + LLVMFastMathApproxFunc = (1 << 6), + LLVMFastMathNone = 0, + LLVMFastMathAll = LLVMFastMathAllowReassoc | LLVMFastMathNoNaNs | + LLVMFastMathNoInfs | LLVMFastMathNoSignedZeros | + LLVMFastMathAllowReciprocal | LLVMFastMathAllowContract | + LLVMFastMathApproxFunc, +}; + +static FastMathFlags mapFromLLVMFastMathFlags(LLVMFastMathFlags FMF) +{ + FastMathFlags NewFMF; + NewFMF.setAllowReassoc((FMF & LLVMFastMathAllowReassoc) != 0); + NewFMF.setNoNaNs((FMF & LLVMFastMathNoNaNs) != 0); + NewFMF.setNoInfs((FMF & LLVMFastMathNoInfs) != 0); + NewFMF.setNoSignedZeros((FMF & LLVMFastMathNoSignedZeros) != 0); + NewFMF.setAllowReciprocal((FMF & LLVMFastMathAllowReciprocal) != 0); + NewFMF.setAllowContract((FMF & LLVMFastMathAllowContract) != 0); + NewFMF.setApproxFunc((FMF & LLVMFastMathApproxFunc) != 0); + + return NewFMF; +} LLVM_C_EXTERN_C_BEGIN @@ -10,4 +148,48 @@ LLVMValueRef LLVMZludaBuildAlloca(LLVMBuilderRef B, LLVMTypeRef Ty, unsigned Add return llvm::wrap(llvm::unwrap(B)->CreateAlloca(llvm::unwrap(Ty), AddrSpace, nullptr, Name)); } +LLVMValueRef LLVMZludaBuildAtomicRMW(LLVMBuilderRef B, LLVMZludaAtomicRMWBinOp op, + LLVMValueRef PTR, LLVMValueRef Val, + char *scope, + LLVMAtomicOrdering ordering) +{ + auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + llvm::AtomicRMWInst::BinOp intop = mapFromLLVMRMWBinOp(op); + return llvm::wrap(builder->CreateAtomicRMW( + intop, llvm::unwrap(PTR), llvm::unwrap(Val), llvm::MaybeAlign(), + mapFromLLVMOrdering(ordering), + context.getOrInsertSyncScopeID(scope))); +} + +LLVMValueRef LLVMZludaBuildAtomicCmpXchg(LLVMBuilderRef B, LLVMValueRef Ptr, + LLVMValueRef Cmp, LLVMValueRef New, + char *scope, + LLVMAtomicOrdering SuccessOrdering, + LLVMAtomicOrdering FailureOrdering) +{ + auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + return wrap(builder->CreateAtomicCmpXchg( + unwrap(Ptr), unwrap(Cmp), unwrap(New), MaybeAlign(), + mapFromLLVMOrdering(SuccessOrdering), + mapFromLLVMOrdering(FailureOrdering), + context.getOrInsertSyncScopeID(scope))); +} + +void LLVMZludaSetFastMathFlags(LLVMValueRef FPMathInst, LLVMFastMathFlags FMF) +{ + Value *P = unwrap(FPMathInst); + cast(P)->setFastMathFlags(mapFromLLVMFastMathFlags(FMF)); +} + +void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering, + char *scope, const char *Name) +{ + auto builder = llvm::unwrap(B); + LLVMContext &context = builder->getContext(); + builder->CreateFence(mapFromLLVMOrdering(Ordering), + context.getOrInsertSyncScopeID(scope)); +} + LLVM_C_EXTERN_C_END \ No newline at end of file diff --git a/llvm_zluda/src/lib.rs b/llvm_zluda/src/lib.rs index 18072a8..fb5cc47 100644 --- a/llvm_zluda/src/lib.rs +++ b/llvm_zluda/src/lib.rs @@ -1,5 +1,48 @@ +#![allow(non_upper_case_globals)] use llvm_sys::prelude::*; pub use llvm_sys::*; + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LLVMZludaAtomicRMWBinOp { + LLVMZludaAtomicRMWBinOpXchg = 0, + LLVMZludaAtomicRMWBinOpAdd = 1, + LLVMZludaAtomicRMWBinOpSub = 2, + LLVMZludaAtomicRMWBinOpAnd = 3, + LLVMZludaAtomicRMWBinOpNand = 4, + LLVMZludaAtomicRMWBinOpOr = 5, + LLVMZludaAtomicRMWBinOpXor = 6, + LLVMZludaAtomicRMWBinOpMax = 7, + LLVMZludaAtomicRMWBinOpMin = 8, + LLVMZludaAtomicRMWBinOpUMax = 9, + LLVMZludaAtomicRMWBinOpUMin = 10, + LLVMZludaAtomicRMWBinOpFAdd = 11, + LLVMZludaAtomicRMWBinOpFSub = 12, + LLVMZludaAtomicRMWBinOpFMax = 13, + LLVMZludaAtomicRMWBinOpFMin = 14, + LLVMZludaAtomicRMWBinOpUIncWrap = 15, + LLVMZludaAtomicRMWBinOpUDecWrap = 16, +} + +// Backport from LLVM 19 +pub const LLVMZludaFastMathAllowReassoc: ::std::ffi::c_uint = 1 << 0; +pub const LLVMZludaFastMathNoNaNs: ::std::ffi::c_uint = 1 << 1; +pub const LLVMZludaFastMathNoInfs: ::std::ffi::c_uint = 1 << 2; +pub const LLVMZludaFastMathNoSignedZeros: ::std::ffi::c_uint = 1 << 3; +pub const LLVMZludaFastMathAllowReciprocal: ::std::ffi::c_uint = 1 << 4; +pub const LLVMZludaFastMathAllowContract: ::std::ffi::c_uint = 1 << 5; +pub const LLVMZludaFastMathApproxFunc: ::std::ffi::c_uint = 1 << 6; +pub const LLVMZludaFastMathNone: ::std::ffi::c_uint = 0; +pub const LLVMZludaFastMathAll: ::std::ffi::c_uint = LLVMZludaFastMathAllowReassoc + | LLVMZludaFastMathNoNaNs + | LLVMZludaFastMathNoInfs + | LLVMZludaFastMathNoSignedZeros + | LLVMZludaFastMathAllowReciprocal + | LLVMZludaFastMathAllowContract + | LLVMZludaFastMathApproxFunc; + +pub type LLVMZludaFastMathFlags = std::ffi::c_uint; + extern "C" { pub fn LLVMZludaBuildAlloca( B: LLVMBuilderRef, @@ -7,4 +50,32 @@ extern "C" { AddrSpace: u32, Name: *const i8, ) -> LLVMValueRef; + + pub fn LLVMZludaBuildAtomicRMW( + B: LLVMBuilderRef, + op: LLVMZludaAtomicRMWBinOp, + PTR: LLVMValueRef, + Val: LLVMValueRef, + scope: *const i8, + ordering: LLVMAtomicOrdering, + ) -> LLVMValueRef; + + pub fn LLVMZludaBuildAtomicCmpXchg( + B: LLVMBuilderRef, + Ptr: LLVMValueRef, + Cmp: LLVMValueRef, + New: LLVMValueRef, + scope: *const i8, + SuccessOrdering: LLVMAtomicOrdering, + FailureOrdering: LLVMAtomicOrdering, + ) -> LLVMValueRef; + + pub fn LLVMZludaSetFastMathFlags(FPMathInst: LLVMValueRef, FMF: LLVMZludaFastMathFlags); + + pub fn LLVMZludaBuildFence( + B: LLVMBuilderRef, + ordering: LLVMAtomicOrdering, + scope: *const i8, + Name: *const i8, + ) -> LLVMValueRef; } diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index e2c4ff8..9f3fa02 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -9,9 +9,6 @@ edition = "2021" [dependencies] ptx_parser = { path = "../ptx_parser" } llvm_zluda = { path = "../llvm_zluda" } -regex = "1" -rspirv = "0.7" -spirv_headers = "1.5" quick-error = "1.2" thiserror = "1.0" bit-vec = "0.6" @@ -21,18 +18,9 @@ rustc-hash = "2.0.0" strum = "0.26" strum_macros = "0.26" -[dependencies.lalrpop-util] -version = "0.19.12" -features = ["lexer"] - -[build-dependencies.lalrpop] -version = "0.19.12" -features = ["lexer"] - [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } comgr = { path = "../comgr" } -spirv_tools-sys = { path = "../spirv_tools-sys" } tempfile = "3" paste = "1.0" cuda-driver-sys = "0.3.0" diff --git a/ptx/build.rs b/ptx/build.rs deleted file mode 100644 index 42c5d59..0000000 --- a/ptx/build.rs +++ /dev/null @@ -1,5 +0,0 @@ -extern crate lalrpop; - -fn main() { - lalrpop::process_root().unwrap(); -} \ No newline at end of file diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2d194c4..6651430 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl deleted file mode 100644 index 86bb593..0000000 --- a/ptx/lib/zluda_ptx_impl.cl +++ /dev/null @@ -1,344 +0,0 @@ -// Every time this file changes it must te rebuilt: -// ocloc -file zluda_ptx_impl.cl -64 -options "-cl-std=CL2.0 -Dcl_intel_bit_instructions -DINTEL" -out_dir . -device kbl -output_no_suffix -spv_only -// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x cl -Xclang -finclude-default-header zluda_ptx_impl.cl -cl-std=CL2.0 -c -target amdgcn-amd-amdhsa -o zluda_ptx_impl.bc -emit-llvm -// Additionally you should strip names: -// spirv-opt --strip-debug zluda_ptx_impl.spv -o zluda_ptx_impl.spv --target-env=spv1.3 - -#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable -#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable - -#define FUNC(NAME) __zluda_ptx_impl__ ## NAME - -#define atomic_inc(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \ - uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \ - uint expected = *ptr; \ - uint desired; \ - do { \ - desired = (expected >= threshold) ? 0 : expected + 1; \ - } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \ - return expected; \ - } - -#define atomic_dec(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \ - uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \ - uint expected = *ptr; \ - uint desired; \ - do { \ - desired = (expected == 0 || expected > threshold) ? threshold : expected - 1; \ - } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \ - return expected; \ - } - -#define atomic_add(NAME, SUCCESS, FAILURE, SCOPE, SPACE, TYPE, ATOMIC_TYPE, INT_TYPE) \ - TYPE FUNC(NAME)(SPACE TYPE* ptr, TYPE value) { \ - volatile SPACE ATOMIC_TYPE* atomic_ptr = (volatile SPACE ATOMIC_TYPE*)ptr; \ - union { \ - INT_TYPE int_view; \ - TYPE float_view; \ - } expected, desired; \ - expected.float_view = *ptr; \ - do { \ - desired.float_view = expected.float_view + value; \ - } while (!atomic_compare_exchange_strong_explicit(atomic_ptr, &expected.int_view, desired.int_view, SUCCESS, FAILURE, SCOPE)); \ - return expected.float_view; \ - } - -// We are doing all this mess instead of accepting memory_order and memory_scope parameters -// because ocloc emits broken (failing spirv-dis) SPIR-V when memory_order or memory_scope is a parameter - -// atom.inc -atomic_inc(atom_relaxed_cta_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, ); -atomic_inc(atom_acquire_cta_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, ); -atomic_inc(atom_release_cta_generic_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, ); -atomic_inc(atom_acq_rel_cta_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, ); - -atomic_inc(atom_relaxed_gpu_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); -atomic_inc(atom_acquire_gpu_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, ); -atomic_inc(atom_release_gpu_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, ); -atomic_inc(atom_acq_rel_gpu_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); - -atomic_inc(atom_relaxed_sys_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); -atomic_inc(atom_acquire_sys_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, ); -atomic_inc(atom_release_sys_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, ); -atomic_inc(atom_acq_rel_sys_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); - -atomic_inc(atom_relaxed_cta_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global); -atomic_inc(atom_acquire_cta_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global); -atomic_inc(atom_release_cta_global_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __global); -atomic_inc(atom_acq_rel_cta_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global); - -atomic_inc(atom_relaxed_gpu_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); -atomic_inc(atom_acquire_gpu_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); -atomic_inc(atom_release_gpu_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global); -atomic_inc(atom_acq_rel_gpu_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); - -atomic_inc(atom_relaxed_sys_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); -atomic_inc(atom_acquire_sys_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); -atomic_inc(atom_release_sys_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global); -atomic_inc(atom_acq_rel_sys_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); - -atomic_inc(atom_relaxed_cta_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local); -atomic_inc(atom_acquire_cta_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local); -atomic_inc(atom_release_cta_shared_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __local); -atomic_inc(atom_acq_rel_cta_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local); - -atomic_inc(atom_relaxed_gpu_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); -atomic_inc(atom_acquire_gpu_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); -atomic_inc(atom_release_gpu_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local); -atomic_inc(atom_acq_rel_gpu_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); - -atomic_inc(atom_relaxed_sys_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); -atomic_inc(atom_acquire_sys_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); -atomic_inc(atom_release_sys_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local); -atomic_inc(atom_acq_rel_sys_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); - -// atom.dec -atomic_dec(atom_relaxed_cta_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, ); -atomic_dec(atom_acquire_cta_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, ); -atomic_dec(atom_release_cta_generic_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, ); -atomic_dec(atom_acq_rel_cta_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, ); - -atomic_dec(atom_relaxed_gpu_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); -atomic_dec(atom_acquire_gpu_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, ); -atomic_dec(atom_release_gpu_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, ); -atomic_dec(atom_acq_rel_gpu_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); - -atomic_dec(atom_relaxed_sys_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, ); -atomic_dec(atom_acquire_sys_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, ); -atomic_dec(atom_release_sys_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, ); -atomic_dec(atom_acq_rel_sys_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, ); - -atomic_dec(atom_relaxed_cta_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global); -atomic_dec(atom_acquire_cta_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global); -atomic_dec(atom_release_cta_global_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __global); -atomic_dec(atom_acq_rel_cta_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global); - -atomic_dec(atom_relaxed_gpu_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); -atomic_dec(atom_acquire_gpu_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); -atomic_dec(atom_release_gpu_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global); -atomic_dec(atom_acq_rel_gpu_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); - -atomic_dec(atom_relaxed_sys_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global); -atomic_dec(atom_acquire_sys_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global); -atomic_dec(atom_release_sys_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global); -atomic_dec(atom_acq_rel_sys_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global); - -atomic_dec(atom_relaxed_cta_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local); -atomic_dec(atom_acquire_cta_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local); -atomic_dec(atom_release_cta_shared_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __local); -atomic_dec(atom_acq_rel_cta_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local); - -atomic_dec(atom_relaxed_gpu_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local); -atomic_dec(atom_acquire_gpu_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); -atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); - -// atom.add.f32 -atomic_add(atom_relaxed_cta_generic_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, , float, atomic_uint, uint); -atomic_add(atom_acquire_cta_generic_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_work_group, , float, atomic_uint, uint); -atomic_add(atom_release_cta_generic_add_f32, memory_order_release, memory_order_acquire, memory_scope_work_group, , float, atomic_uint, uint); -atomic_add(atom_acq_rel_cta_generic_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, , float, atomic_uint, uint); - -atomic_add(atom_relaxed_gpu_generic_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_acquire_gpu_generic_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_release_gpu_generic_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_acq_rel_gpu_generic_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); - -atomic_add(atom_relaxed_sys_generic_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_acquire_sys_generic_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_release_sys_generic_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); -atomic_add(atom_acq_rel_sys_generic_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, , float, atomic_uint, uint); - -atomic_add(atom_relaxed_cta_global_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global, float, atomic_uint, uint); -atomic_add(atom_acquire_cta_global_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global, float, atomic_uint, uint); -atomic_add(atom_release_cta_global_add_f32, memory_order_release, memory_order_acquire, memory_scope_work_group, __global, float, atomic_uint, uint); -atomic_add(atom_acq_rel_cta_global_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global, float, atomic_uint, uint); - -atomic_add(atom_relaxed_gpu_global_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_acquire_gpu_global_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_release_gpu_global_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_acq_rel_gpu_global_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); - -atomic_add(atom_relaxed_sys_global_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_acquire_sys_global_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_release_sys_global_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); -atomic_add(atom_acq_rel_sys_global_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global, float, atomic_uint, uint); - -atomic_add(atom_relaxed_cta_shared_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local, float, atomic_uint, uint); -atomic_add(atom_acquire_cta_shared_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local, float, atomic_uint, uint); -atomic_add(atom_release_cta_shared_add_f32, memory_order_release, memory_order_acquire, memory_scope_work_group, __local, float, atomic_uint, uint); -atomic_add(atom_acq_rel_cta_shared_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local, float, atomic_uint, uint); - -atomic_add(atom_relaxed_gpu_shared_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_acquire_gpu_shared_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_release_gpu_shared_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_acq_rel_gpu_shared_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); - -atomic_add(atom_relaxed_sys_shared_add_f32, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_acquire_sys_shared_add_f32, memory_order_acquire, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_release_sys_shared_add_f32, memory_order_release, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); -atomic_add(atom_acq_rel_sys_shared_add_f32, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local, float, atomic_uint, uint); - -atomic_add(atom_relaxed_cta_generic_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, , double, atomic_ulong, ulong); -atomic_add(atom_acquire_cta_generic_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_work_group, , double, atomic_ulong, ulong); -atomic_add(atom_release_cta_generic_add_f64, memory_order_release, memory_order_acquire, memory_scope_work_group, , double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_cta_generic_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, , double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_gpu_generic_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_acquire_gpu_generic_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_release_gpu_generic_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_gpu_generic_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_sys_generic_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_acquire_sys_generic_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_release_sys_generic_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_sys_generic_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, , double, atomic_ulong, ulong); -// atom.add.f64 -atomic_add(atom_relaxed_cta_global_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global, double, atomic_ulong, ulong); -atomic_add(atom_acquire_cta_global_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global, double, atomic_ulong, ulong); -atomic_add(atom_release_cta_global_add_f64, memory_order_release, memory_order_acquire, memory_scope_work_group, __global, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_cta_global_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global, double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_gpu_global_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_acquire_gpu_global_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_release_gpu_global_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_gpu_global_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_sys_global_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_acquire_sys_global_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_release_sys_global_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_sys_global_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global, double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_cta_shared_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local, double, atomic_ulong, ulong); -atomic_add(atom_acquire_cta_shared_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local, double, atomic_ulong, ulong); -atomic_add(atom_release_cta_shared_add_f64, memory_order_release, memory_order_acquire, memory_scope_work_group, __local, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_cta_shared_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local, double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_gpu_shared_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_acquire_gpu_shared_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_release_gpu_shared_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_gpu_shared_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); - -atomic_add(atom_relaxed_sys_shared_add_f64, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_acquire_sys_shared_add_f64, memory_order_acquire, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_release_sys_shared_add_f64, memory_order_release, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); -atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local, double, atomic_ulong, ulong); - -#ifdef INTEL - uint FUNC(bfe_u32)(uint base, uint pos, uint len) { - return intel_ubfe(base, pos, len); - } - - ulong FUNC(bfe_u64)(ulong base, uint pos, uint len) { - return intel_ubfe(base, pos, len); - } - - int FUNC(bfe_s32)(int base, uint pos, uint len) { - return intel_sbfe(base, pos, len); - } - - long FUNC(bfe_s64)(long base, uint pos, uint len) { - return intel_sbfe(base, pos, len); - } - - uint FUNC(bfi_b32)(uint insert, uint base, uint offset, uint count) { - return intel_bfi(base, insert, offset, count); - } - - ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) { - return intel_bfi(base, insert, offset, count); - } - - uint FUNC(brev_b32)(uint base) { - return intel_bfrev(base); - } - - ulong FUNC(brev_b64)(ulong base) { - return intel_bfrev(base); - } -#else - uint FUNC(bfe_u32)(uint base, uint pos, uint len) { - return amd_bfe(base, pos, len); - } - - ulong FUNC(bfe_u64)(ulong base, uint pos, uint len) { - return (base >> pos) & len; - } - - int FUNC(bfe_s32)(int base, uint pos, uint len) { - return amd_bfe(base, pos, len); - } - - long FUNC(bfe_s64)(long base, uint pos, uint len) { - return (base >> pos) & len; - } - - uint FUNC(bfi_b32)(uint insert, uint base, uint offset, uint count) { - uint mask = amd_bfm(count, offset); - return (~mask & base) | (mask & insert); - } - - ulong FUNC(bfi_b64)(ulong insert, ulong base, uint offset, uint count) { - ulong mask = ((1UL << (count & 0x3f)) - 1UL) << (offset & 0x3f); - return (~mask & base) | (mask & insert); - } - - extern __attribute__((const)) uint __llvm_bitreverse_i32(uint) __asm("llvm.bitreverse.i32"); - uint FUNC(brev_b32)(uint base) { - return __llvm_bitreverse_i32(base); - } - - extern __attribute__((const)) ulong __llvm_bitreverse_i64(ulong) __asm("llvm.bitreverse.i64"); - ulong FUNC(brev_b64)(ulong base) { - return __llvm_bitreverse_i64(base); - } - - // Taken from __ballot definition in hipamd/include/hip/amd_detail/amd_device_functions.h - uint FUNC(activemask)() { - return (uint)__builtin_amdgcn_uicmp(1, 0, 33); - } - - uint FUNC(sreg_tid)(uchar dim) { - return (uint)get_local_id(dim); - } - - uint FUNC(sreg_ntid)(uchar dim) { - return (uint)get_local_size(dim); - } - - uint FUNC(sreg_ctaid)(uchar dim) { - return (uint)get_group_id(dim); - } - - uint FUNC(sreg_nctaid)(uchar dim) { - return (uint)get_num_groups(dim); - } - - uint FUNC(sreg_clock)() { - return (uint)__builtin_amdgcn_s_memtime(); - } - - // Taken from __ballot definition in hipamd/include/hip/amd_detail/amd_device_functions.h - // They return active threads, which I think is incorrect - extern __attribute__((const)) uint __ockl_lane_u32(); - uint FUNC(sreg_lanemask_lt)() { - uint lane_idx = __ockl_lane_u32(); - ulong mask = (1UL << lane_idx) - 1UL; - return (uint)mask; - } -#endif - -void FUNC(__assertfail)( - __attribute__((unused)) __private ulong* message, - __attribute__((unused)) __private ulong* file, - __attribute__((unused)) __private uint* line, - __attribute__((unused)) __private ulong* function, - __attribute__((unused)) __private ulong* charSize -) { -} - -uint FUNC(vprintf)( - __attribute__((unused)) __generic void* format, - __attribute__((unused)) __generic void* valist -) { - return 0; -} diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp new file mode 100644 index 0000000..f1b416d --- /dev/null +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -0,0 +1,151 @@ +// Every time this file changes it must te rebuilt, you need llvm-17: +// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && llvm-dis-17 zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | llvm-as-17 - -o zluda_ptx_impl.bc && llvm-dis-17 zluda_ptx_impl.bc + +#include +#include + +#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME + +extern "C" +{ + uint32_t FUNC(activemask)() + { + return __builtin_amdgcn_read_exec_lo(); + } + + size_t __ockl_get_local_id(uint32_t) __device__; + uint32_t FUNC(sreg_tid)(uint8_t member) + { + return (uint32_t)__ockl_get_local_id(member); + } + + size_t __ockl_get_local_size(uint32_t) __device__; + uint32_t FUNC(sreg_ntid)(uint8_t member) + { + return (uint32_t)__ockl_get_local_size(member); + } + + size_t __ockl_get_global_id(uint32_t) __device__; + uint32_t FUNC(sreg_ctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_id(member); + } + + size_t __ockl_get_global_size(uint32_t) __device__; + uint32_t FUNC(sreg_nctaid)(uint8_t member) + { + return (uint32_t)__ockl_get_global_size(member); + } + + uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device)); + uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32) + { + uint32_t pos = pos_32 & 0xFFU; + uint32_t len = len_32 & 0xFFU; + if (pos >= 32) + return 0; + // V_BFE_U32 only uses bits [4:0] for len (max value is 31) + if (len >= 32) + return base >> pos; + len = std::min(len, 31U); + return __ockl_bfe_u32(base, pos, len); + } + + // LLVM contains mentions of llvm.amdgcn.ubfe.i64 and llvm.amdgcn.sbfe.i64, + // but using it only leads to LLVM crashes on RDNA2 + uint64_t FUNC(bfe_u64)(uint64_t base, uint32_t pos, uint32_t len) + { + // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len` + // parameters use whole 32 bit number and not just bottom 8 bits + if (pos >= 64) + return 0; + if (len >= 64) + return base >> pos; + len = std::min(len, 63U); + return (base >> pos) & ((1UL << len) - 1UL); + } + + int32_t __ockl_bfe_i32(int32_t, uint32_t, uint32_t) __attribute__((device)); + int32_t FUNC(bfe_s32)(int32_t base, uint32_t pos_32, uint32_t len_32) + { + uint32_t pos = pos_32 & 0xFFU; + uint32_t len = len_32 & 0xFFU; + if (len == 0) + return 0; + if (pos >= 32) + return (base >> 31); + // V_BFE_I32 only uses bits [4:0] for len (max value is 31) + if (len >= 32) + return base >> pos; + len = std::min(len, 31U); + return __ockl_bfe_i32(base, pos, len); + } + + static __device__ uint32_t add_sat(uint32_t x, uint32_t y) + { + uint32_t result; + if (__builtin_add_overflow(x, y, &result)) + { + return UINT32_MAX; + } + else + { + return result; + } + } + + static __device__ uint32_t sub_sat(uint32_t x, uint32_t y) + { + uint32_t result; + if (__builtin_sub_overflow(x, y, &result)) + { + return 0; + } + else + { + return result; + } + } + + int64_t FUNC(bfe_s64)(int64_t base, uint32_t pos, uint32_t len) + { + // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len` + // parameters use whole 32 bit number and not just bottom 8 bits + if (len == 0) + return 0; + if (pos >= 64) + return (base >> 63U); + if (add_sat(pos, len) >= 64) + len = sub_sat(64, pos); + return (base << (64U - pos - len)) >> (64U - len); + } + + uint32_t __ockl_bfm_u32(uint32_t count, uint32_t offset) __attribute__((device)); + uint32_t FUNC(bfi_b32)(uint32_t insert, uint32_t base, uint32_t pos_32, uint32_t len_32) + { + uint32_t pos = pos_32 & 0xFFU; + uint32_t len = len_32 & 0xFFU; + if (pos >= 32) + return base; + uint32_t mask; + if (len >= 32) + mask = UINT32_MAX << pos; + else + mask = __ockl_bfm_u32(len, pos); + return (~mask & base) | (mask & (insert << pos)); + } + + uint64_t FUNC(bfi_b64)(uint64_t insert, uint64_t base, uint32_t pos, uint32_t len) + { + // NVIDIA docs are incorrect. In 64 bit `bfe` both `pos` and `len` + // parameters use whole 32 bit number and not just bottom 8 bits + if (pos >= 64) + return base; + uint64_t mask; + if (len >= 64) + mask = UINT64_MAX << pos; + else + mask = ((1UL << len) - 1UL) << (pos); + return (~mask & base) | (mask & (insert << pos)); + } +} diff --git a/ptx/lib/zluda_ptx_impl.spv b/ptx/lib/zluda_ptx_impl.spv deleted file mode 100644 index e9fc938..0000000 Binary files a/ptx/lib/zluda_ptx_impl.spv and /dev/null differ diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs deleted file mode 100644 index 358b8ce..0000000 --- a/ptx/src/ast.rs +++ /dev/null @@ -1,1074 +0,0 @@ -use half::f16; -use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; -use std::{marker::PhantomData, num::ParseIntError}; - -#[derive(Debug, thiserror::Error)] -pub enum PtxError { - #[error("{source}")] - ParseInt { - #[from] - source: ParseIntError, - }, - #[error("{source}")] - ParseFloat { - #[from] - source: ParseFloatError, - }, - #[error("")] - Unsupported32Bit, - #[error("")] - SyntaxError, - #[error("")] - NonF32Ftz, - #[error("")] - WrongArrayType, - #[error("")] - WrongVectorElement, - #[error("")] - MultiArrayVariable, - #[error("")] - ZeroDimensionArray, - #[error("")] - ArrayInitalizer, - #[error("")] - NonExternPointer, - #[error("{start}:{end}")] - UnrecognizedStatement { start: usize, end: usize }, - #[error("{start}:{end}")] - UnrecognizedDirective { start: usize, end: usize }, -} - -// For some weird reson this is illegal: -// .param .f16x2 foobar; -// but this is legal: -// .param .f16x2 foobar[1]; -// even more interestingly this is legal, but only in .func (not in .entry): -// .param .b32 foobar[] - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum BarDetails { - SyncAligned, -} - -pub trait UnwrapWithVec { - fn unwrap_with(self, errs: &mut Vec) -> To; -} - -impl, EInto> UnwrapWithVec - for Result -{ - fn unwrap_with(self, errs: &mut Vec) -> R { - self.unwrap_or_else(|e| { - errs.push(e.into()); - R::default() - }) - } -} - -impl< - R1: Default, - EFrom1: std::convert::Into, - R2: Default, - EFrom2: std::convert::Into, - EInto, - > UnwrapWithVec for (Result, Result) -{ - fn unwrap_with(self, errs: &mut Vec) -> (R1, R2) { - let (x, y) = self; - let r1 = x.unwrap_with(errs); - let r2 = y.unwrap_with(errs); - (r1, r2) - } -} - -pub struct Module<'a> { - pub version: (u8, u8), - pub directives: Vec>>, -} - -pub enum Directive<'a, P: ArgParams> { - Variable(LinkingDirective, Variable), - Method(LinkingDirective, Function<'a, &'a str, Statement

>), -} - -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -pub enum MethodName<'input, ID> { - Kernel(&'input str), - Func(ID), -} - -pub struct MethodDeclaration<'input, ID> { - pub return_arguments: Vec>, - pub name: MethodName<'input, ID>, - pub input_arguments: Vec>, - pub shared_mem: Option, -} - -pub struct Function<'a, ID, S> { - pub func_directive: MethodDeclaration<'a, ID>, - pub tuning: Vec, - pub body: Option>, -} - -pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; - -#[derive(PartialEq, Eq, Clone)] -pub enum Type { - // .param.b32 foo; - // -> OpTypeInt - Scalar(ScalarType), - // .param.v2.b32 foo; - // -> OpTypeVector - Vector(ScalarType, u8), - // .param.b32 foo[4]; - // -> OpTypeArray - Array(ScalarType, Vec), - /* - Variables of this type almost never exist in the original .ptx and are - usually artificially created. Some examples below: - - extern pointers to the .shared memory in the form: - .extern .shared .b32 shared_mem[]; - which we first parse as - .extern .shared .b32 shared_mem; - and then convert to an additional function parameter: - .param .ptr<.b32.shared> shared_mem; - and do a load at the start of the function (and renames inside fn): - .reg .ptr<.b32.shared> temp; - ld.param.ptr<.b32.shared> temp, [shared_mem]; - note, we don't support non-.shared extern pointers, because there's - zero use for them in the ptxas - - artifical pointers created by stateful conversion, which work - similiarly to the above - - function parameters: - foobar(.param .align 4 .b8 numbers[]) - which get parsed to - foobar(.param .align 4 .b8 numbers) - and then converted to - foobar(.reg .align 4 .ptr<.b8.param> numbers) - - ld/st with offset: - .reg.b32 x; - .param.b64 arg0; - st.param.b32 [arg0+4], x; - Yes, this code is legal and actually emitted by the NV compiler! - We convert the st to: - .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); - st.param.b32 [temp], x; - */ - // .reg ptr<.b64.param> - // -> OpTypePointer Function - Pointer(ScalarType, StateSpace), -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum ScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, - F16x2, - Pred, -} - -impl ScalarType { - pub fn size_of(self) -> u8 { - match self { - ScalarType::U8 => 1, - ScalarType::S8 => 1, - ScalarType::B8 => 1, - ScalarType::U16 => 2, - ScalarType::S16 => 2, - ScalarType::B16 => 2, - ScalarType::F16 => 2, - ScalarType::U32 => 4, - ScalarType::S32 => 4, - ScalarType::B32 => 4, - ScalarType::F32 => 4, - ScalarType::U64 => 8, - ScalarType::S64 => 8, - ScalarType::B64 => 8, - ScalarType::F64 => 8, - ScalarType::F16x2 => 4, - ScalarType::Pred => 1, - } - } -} - -impl Default for ScalarType { - fn default() -> Self { - ScalarType::B8 - } -} - -pub enum Statement { - Label(P::Id), - Variable(MultiVariable), - Instruction(Option>, Instruction

), - Block(Vec>), -} - -pub struct MultiVariable { - pub var: Variable, - pub count: Option, -} - -#[derive(Clone)] -pub struct Variable { - pub align: Option, - pub v_type: Type, - pub state_space: StateSpace, - pub name: ID, - pub array_init: Vec, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, - Generic, - Sreg, -} - -pub struct PredAt { - pub not: bool, - pub label: ID, -} - -pub enum Instruction { - Ld(LdDetails, Arg2Ld

), - Mov(MovDetails, Arg2Mov

), - Mul(MulDetails, Arg3

), - Add(ArithDetails, Arg3

), - Setp(SetpData, Arg4Setp

), - SetpBool(SetpBoolData, Arg5Setp

), - Not(ScalarType, Arg2

), - Bra(BraData, Arg1

), - Cvt(CvtDetails, Arg2

), - Cvta(CvtaDetails, Arg2

), - Shl(ScalarType, Arg3

), - Shr(ScalarType, Arg3

), - St(StData, Arg2St

), - Ret(RetData), - Call(CallInst

), - Abs(AbsDetails, Arg2

), - Mad(MulDetails, Arg4

), - Fma(ArithFloat, Arg4

), - Or(ScalarType, Arg3

), - Sub(ArithDetails, Arg3

), - Min(MinMaxDetails, Arg3

), - Max(MinMaxDetails, Arg3

), - Rcp(RcpDetails, Arg2

), - And(ScalarType, Arg3

), - Selp(ScalarType, Arg4

), - Bar(BarDetails, Arg1Bar

), - Atom(AtomDetails, Arg3

), - AtomCas(AtomCasDetails, Arg4

), - Div(DivDetails, Arg3

), - Sqrt(SqrtDetails, Arg2

), - Rsqrt(RsqrtDetails, Arg2

), - Neg(NegDetails, Arg2

), - Sin { flush_to_zero: bool, arg: Arg2

}, - Cos { flush_to_zero: bool, arg: Arg2

}, - Lg2 { flush_to_zero: bool, arg: Arg2

}, - Ex2 { flush_to_zero: bool, arg: Arg2

}, - Clz { typ: ScalarType, arg: Arg2

}, - Brev { typ: ScalarType, arg: Arg2

}, - Popc { typ: ScalarType, arg: Arg2

}, - Xor { typ: ScalarType, arg: Arg3

}, - Bfe { typ: ScalarType, arg: Arg4

}, - Bfi { typ: ScalarType, arg: Arg5

}, - Rem { typ: ScalarType, arg: Arg3

}, - Prmt { control: u16, arg: Arg3

}, - Activemask { arg: Arg1

}, - Membar { level: MemScope }, -} - -#[derive(Copy, Clone)] -pub struct MadFloatDesc {} - -#[derive(Copy, Clone)] -pub struct AbsDetails { - pub flush_to_zero: Option, - pub typ: ScalarType, -} -#[derive(Copy, Clone)] -pub struct RcpDetails { - pub rounding: Option, - pub flush_to_zero: Option, - pub is_f64: bool, -} - -pub struct CallInst { - pub uniform: bool, - pub ret_params: Vec, - pub func: P::Id, - pub param_list: Vec, -} - -pub trait ArgParams { - type Id; - type Operand; -} - -pub struct ParsedArgParams<'a> { - _marker: PhantomData<&'a ()>, -} - -impl<'a> ArgParams for ParsedArgParams<'a> { - type Id = &'a str; - type Operand = Operand<&'a str>; -} - -pub struct Arg1 { - pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand -} - -pub struct Arg1Bar { - pub src: P::Operand, -} - -pub struct Arg2 { - pub dst: P::Operand, - pub src: P::Operand, -} -pub struct Arg2Ld { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg2St { - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg2Mov { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg3 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg4 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -pub struct Arg4Setp { - pub dst1: P::Id, - pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg5 { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, - pub src4: P::Operand, -} - -pub struct Arg5Setp { - pub dst1: P::Id, - pub dst2: Option, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -#[derive(Copy, Clone)] -pub enum ImmediateValue { - U64(u64), - S64(i64), - F32(f32), - F64(f64), -} - -#[derive(Clone)] -pub enum Operand { - Reg(Id), - RegOffset(Id, i32), - Imm(ImmediateValue), - VecMember(Id, u8), - VecPack(Vec), -} - -pub enum VectorPrefix { - V2, - V4, -} - -pub struct LdDetails { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - Relaxed(MemScope), - Acquire(MemScope), -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MemScope { - Cta, - Gpu, - Sys, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - pub fn new(typ: Type) -> Self { - MovDetails { - typ, - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -#[derive(Copy, Clone)] -pub struct MulIntDesc { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MulIntControl { - Low, - High, - Wide, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -pub struct AddIntDesc { - pub typ: ScalarType, - pub saturate: bool, -} - -pub struct SetpData { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub cmp_op: SetpCompareOp, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareOp { - Eq, - NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, - NanEq, - NanNotEq, - NanLess, - NanLessOrEq, - NanGreater, - NanGreaterOrEq, - IsNotNan, - IsAnyNan, -} - -pub enum SetpBoolPostOp { - And, - Or, - Xor, -} - -pub struct SetpBoolData { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub cmp_op: SetpCompareOp, - pub bool_op: SetpBoolPostOp, -} - -pub struct BraData { - pub uniform: bool, -} - -pub enum CvtDetails { - IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), -} - -pub struct CvtIntToIntDesc { - pub dst: ScalarType, - pub src: ScalarType, - pub saturate: bool, -} - -pub struct CvtDesc { - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, - pub dst: ScalarType, - pub src: ScalarType, -} - -impl CvtDetails { - pub fn new_int_from_int_checked<'err, 'input>( - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if saturate { - if src.kind() == ScalarKind::Signed { - if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::User { - error: PtxError::SyntaxError, - }); - } - } else { - if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::User { - error: PtxError::SyntaxError, - }); - } - } - } - CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate }) - } - - pub fn new_float_from_int_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(lalrpop_util::ParseError::User { - error: PtxError::NonF32Ftz, - })); - } - CvtDetails::FloatFromInt(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } - - pub fn new_int_from_float_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec, PtxError>>, - ) -> Self { - if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(lalrpop_util::ParseError::User { - error: PtxError::NonF32Ftz, - })); - } - CvtDetails::IntFromFloat(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } -} - -pub struct CvtaDetails { - pub to: StateSpace, - pub from: StateSpace, - pub size: CvtaSize, -} - -pub enum CvtaSize { - U32, - U64, -} - -pub struct StData { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: StCacheOperator, - pub typ: Type, -} - -#[derive(PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - -pub struct RetData { - pub uniform: bool, -} - -#[derive(Copy, Clone)] -pub enum MulDetails { - Unsigned(MulUInt), - Signed(MulSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct MulUInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub struct MulSInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(ScalarType), - Signed(ArithSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: ScalarType, - pub rounding: Option, - pub flush_to_zero: Option, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub enum MinMaxDetails { - Signed(ScalarType), - Unsigned(ScalarType), - Float(MinMaxFloat), -} - -#[derive(Copy, Clone)] -pub struct MinMaxFloat { - pub flush_to_zero: Option, - pub nan: bool, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub struct AtomDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub inner: AtomInnerDetails, -} - -#[derive(Copy, Clone)] -pub enum AtomSemantics { - Relaxed, - Acquire, - Release, - AcquireRelease, -} - -#[derive(Copy, Clone)] -pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: ScalarType }, - Unsigned { op: AtomUIntOp, typ: ScalarType }, - Signed { op: AtomSIntOp, typ: ScalarType }, - Float { op: AtomFloatOp, typ: ScalarType }, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomBitOp { - And, - Or, - Xor, - Exchange, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomUIntOp { - Add, - Inc, - Dec, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomSIntOp { - Add, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomFloatOp { - Add, -} - -#[derive(Copy, Clone)] -pub struct AtomCasDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub enum DivDetails { - Unsigned(ScalarType), - Signed(ScalarType), - Float(DivFloatDetails), -} - -#[derive(Copy, Clone)] -pub struct DivFloatDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub kind: DivFloatKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - Full, - Rounding(RoundingMode), -} - -pub enum NumsOrArrays<'a> { - Nums(Vec<(&'a str, u32)>), - Arrays(Vec>), -} - -#[derive(Copy, Clone)] -pub struct SqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, - pub kind: SqrtKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum SqrtKind { - Approx, - Rounding(RoundingMode), -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct RsqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: bool, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct NegDetails { - pub typ: ScalarType, - pub flush_to_zero: Option, -} - -impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result, PtxError> { - self.normalize_dimensions(dimensions)?; - let sizeof_t = ScalarType::from(typ).size_of() as usize; - let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); - let mut result = vec![0; result_size]; - self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; - Ok(result) - } - - fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { - match dimensions.first_mut() { - Some(first) => { - if *first == 0 { - *first = match self { - NumsOrArrays::Nums(v) => v.len() as u32, - NumsOrArrays::Arrays(v) => v.len() as u32, - }; - } - } - None => return Err(PtxError::ZeroDimensionArray), - } - for dim in dimensions { - if *dim == 0 { - return Err(PtxError::ZeroDimensionArray); - } - } - Ok(()) - } - - fn parse_and_copy( - &self, - t: ScalarType, - size_of_t: usize, - dimensions: &[u32], - result: &mut [u8], - ) -> Result<(), PtxError> { - match dimensions { - [] => unreachable!(), - [dim] => match self { - NumsOrArrays::Nums(vec) => { - if vec.len() > *dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - for (idx, (val, radix)) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, *radix, result)?; - } - } - NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), - }, - [first_dim, rest @ ..] => match self { - NumsOrArrays::Arrays(vec) => { - if vec.len() > *first_dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); - for (idx, this) in vec.iter().enumerate() { - this.parse_and_copy( - t, - size_of_t, - rest, - &mut result[(size_of_element * idx)..], - )?; - } - } - NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), - }, - } - Ok(()) - } - - fn parse_and_copy_single( - t: ScalarType, - idx: usize, - str_val: &str, - radix: u32, - output: &mut [u8], - ) -> Result<(), PtxError> { - match t { - ScalarType::B8 | ScalarType::U8 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B16 | ScalarType::U16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B32 | ScalarType::U32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::B64 | ScalarType::U64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S8 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::S64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F16 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F16x2 => todo!(), - ScalarType::F32 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::F64 => { - Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; - } - ScalarType::Pred => todo!(), - } - Ok(()) - } - - fn parse_and_copy_single_t( - idx: usize, - str_val: &str, - _radix: u32, // TODO: use this to properly support hex literals - output: &mut [u8], - ) -> Result<(), PtxError> - where - T::Err: Into, - { - let typed_output = unsafe { - std::slice::from_raw_parts_mut::( - output.as_mut_ptr() as *mut _, - output.len() / mem::size_of::(), - ) - }; - typed_output[idx] = str_val.parse::().map_err(|e| e.into())?; - Ok(()) - } -} - -pub enum ArrayOrPointer { - Array { dimensions: Vec, init: Vec }, - Pointer, -} - -bitflags! { - pub struct LinkingDirective: u8 { - const NONE = 0b000; - const EXTERN = 0b001; - const VISIBLE = 0b10; - const WEAK = 0b100; - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum TuningDirective { - MaxNReg(u32), - MaxNtid(u32, u32, u32), - ReqNtid(u32, u32, u32), - MinNCtaPerSm(u32), -} - -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - -impl ScalarType { - pub fn kind(self) -> ScalarKind { - match self { - ScalarType::U8 => ScalarKind::Unsigned, - ScalarType::U16 => ScalarKind::Unsigned, - ScalarType::U32 => ScalarKind::Unsigned, - ScalarType::U64 => ScalarKind::Unsigned, - ScalarType::S8 => ScalarKind::Signed, - ScalarType::S16 => ScalarKind::Signed, - ScalarType::S32 => ScalarKind::Signed, - ScalarType::S64 => ScalarKind::Signed, - ScalarType::B8 => ScalarKind::Bit, - ScalarType::B16 => ScalarKind::Bit, - ScalarType::B32 => ScalarKind::Bit, - ScalarType::B64 => ScalarKind::Bit, - ScalarType::F16 => ScalarKind::Float, - ScalarType::F32 => ScalarKind::Float, - ScalarType::F64 => ScalarKind::Float, - ScalarType::F16x2 => ScalarKind::Float2, - ScalarType::Pred => ScalarKind::Pred, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn array_fails_multiple_0_dmiensions() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); - } - - #[test] - fn array_fails_on_empty() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); - } - - #[test] - fn array_auto_sizes_0_dimension() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert_eq!( - vec![1u8, 2, 3, 4], - inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() - ); - assert_eq!(dimensions, vec![2u32, 2]); - } - - #[test] - fn array_fails_wrong_structure() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } - - #[test] - fn array_fails_too_long_component() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), - NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } -} diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 5e95dae..da972f6 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -1,186 +1,6 @@ -#[cfg(test)] -extern crate paste; -#[macro_use] -extern crate lalrpop_util; -#[macro_use] -extern crate quick_error; - -extern crate bit_vec; -extern crate half; -#[cfg(test)] -extern crate hip_runtime_sys as hip; -extern crate rspirv; -extern crate spirv_headers as spirv; - -#[cfg(test)] -extern crate spirv_tools_sys as spirv_tools; - -#[macro_use] -extern crate bitflags; - -lalrpop_mod!( - #[allow(warnings)] - ptx -); - -pub mod ast; pub(crate) mod pass; #[cfg(test)] mod test; -mod translate; - -use std::fmt; - -pub use crate::ptx::ModuleParser; -use ast::PtxError; -pub use lalrpop_util::lexer::Token; -pub use lalrpop_util::ParseError; -pub use rspirv::dr::Error as SpirvError; -pub use translate::to_spirv_module; -pub use translate::KernelInfo; -pub use translate::TranslateError; - -pub trait ModuleParserExt { - fn parse_checked<'input>( - txt: &'input str, - ) -> Result, Vec, ast::PtxError>>>; - - // Returned AST might be malformed. Some users, like logger, want to look at - // malformed AST to record information - list of kernels or such - fn parse_unchecked<'input>( - txt: &'input str, - ) -> ( - ast::Module<'input>, - Vec, ast::PtxError>>, - ); -} - -impl ModuleParserExt for ModuleParser { - fn parse_checked<'input>( - txt: &'input str, - ) -> Result, Vec, ast::PtxError>>> { - let mut errors = Vec::new(); - let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt); - match (&*errors, maybe_ast) { - (&[], Ok(ast)) => Ok(ast), - (_, Err(unrecoverable)) => { - errors.push(unrecoverable); - Err(errors) - } - (_, Ok(_)) => Err(errors), - } - } - - fn parse_unchecked<'input>( - txt: &'input str, - ) -> ( - ast::Module<'input>, - Vec, ast::PtxError>>, - ) { - let mut errors = Vec::new(); - let maybe_ast = ptx::ModuleParser::new().parse(&mut errors, txt); - let ast = match maybe_ast { - Ok(ast) => ast, - Err(unrecoverable_err) => { - errors.push(unrecoverable_err); - ast::Module { - version: (0, 0), - directives: Vec::new(), - } - } - }; - (ast, errors) - } -} - -pub struct DisplayParseError<'a, Loc, Tok, Err>(&'a str, &'a ParseError); - -impl<'a, Loc: fmt::Display + Into + Copy, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> { - // unsafe because there's no guarantee that the input str is the one that this error was created from - pub unsafe fn new(error: &'a ParseError, text: &'a str) -> Self { - Self(text, error) - } -} - -impl<'a, Loc, Tok> fmt::Display for DisplayParseError<'a, Loc, Tok, PtxError> -where - Loc: fmt::Display, - Tok: fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.1 { - ParseError::User { - error: PtxError::UnrecognizedStatement { start, end }, - } => self.fmt_unrecognized(f, *start, *end, "statement"), - ParseError::User { - error: PtxError::UnrecognizedDirective { start, end }, - } => self.fmt_unrecognized(f, *start, *end, "directive"), - _ => self.1.fmt(f), - } - } -} - -impl<'a, Loc, Tok, Err> DisplayParseError<'a, Loc, Tok, Err> { - fn fmt_unrecognized( - &self, - f: &mut fmt::Formatter, - start: usize, - end: usize, - kind: &'static str, - ) -> fmt::Result { - let full_substring = unsafe { self.0.get_unchecked(start..end) }; - write!( - f, - "Unrecognized {} `{}` found at {}:{}", - kind, full_substring, start, end - ) - } -} - -pub(crate) fn without_none(x: Vec>) -> Vec { - x.into_iter().filter_map(|x| x).collect() -} - -pub(crate) fn vector_index<'input>( - inp: &'input str, -) -> Result, ast::PtxError>> { - match inp { - "x" | "r" => Ok(0), - "y" | "g" => Ok(1), - "z" | "b" => Ok(2), - "w" | "a" => Ok(3), - _ => Err(ParseError::User { - error: ast::PtxError::WrongVectorElement, - }), - } -} - -#[cfg(test)] -mod tests { - use crate::{DisplayParseError, ModuleParser, ModuleParserExt}; - #[test] - fn error_report_unknown_instructions() { - let module = r#" - .version 6.5 - .target sm_30 - .address_size 64 +pub use pass::to_llvm_module; - .visible .entry add( - .param .u64 input, - ) - { - .reg .u64 x; - does_not_exist.u64 x, x; - ret; - }"#; - let errors = match ModuleParser::parse_checked(module) { - Err(e) => e, - Ok(_) => panic!(), - }; - assert_eq!(errors.len(), 1); - let reporter = DisplayParseError(module, &errors[0]); - let build_log_string = format!("{}", reporter); - assert!(build_log_string.contains("does_not_exist")); - } -} diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs deleted file mode 100644 index 1dac7fd..0000000 --- a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::collections::{BTreeMap, BTreeSet}; - -use super::*; - -/* - PTX represents dynamically allocated shared local memory as - .extern .shared .b32 shared_mem[]; - In SPIRV/OpenCL world this is expressed as an additional argument to the kernel - And in AMD compilation - This pass looks for all uses of .extern .shared and converts them to - an additional method argument - The question is how this artificial argument should be expressed. There are - several options: - * Straight conversion: - .shared .b32 shared_mem[] - * Introduce .param_shared statespace: - .param_shared .b32 shared_mem - or - .param_shared .b32 shared_mem[] - * Introduce .shared_ptr type: - .param .shared_ptr .b32 shared_mem - * Reuse .ptr hint: - .param .u64 .ptr shared_mem - This is the most tempting, but also the most nonsensical, .ptr is just a - hint, which has no semantical meaning (and the output of our - transformation has a semantical meaning - we emit additional - "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") -*/ -pub(super) fn run<'input>( - module: Vec>, - kernels_methods_call_map: &MethodsCallMap<'input>, - new_id: &mut impl FnMut() -> SpirvWord, -) -> Result>, TranslateError> { - let mut globals_shared = HashMap::new(); - for dir in module.iter() { - match dir { - Directive::Variable( - _, - ast::Variable { - state_space: ast::StateSpace::Shared, - name, - v_type, - .. - }, - ) => { - globals_shared.insert(*name, v_type.clone()); - } - _ => {} - } - } - if globals_shared.len() == 0 { - return Ok(module); - } - let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); - let module = module - .into_iter() - .map(|directive| match directive { - Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - }) => { - let call_key = (*func_decl).borrow().name; - let statements = statements - .into_iter() - .map(|statement| { - statement.visit_map( - &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { - if let Some(_) = globals_shared.get(&id) { - methods_to_directly_used_shared_globals - .entry(call_key) - .or_insert_with(HashSet::new) - .insert(id); - } - Ok::<_, TranslateError>(id) - }, - ) - }) - .collect::, _>>()?; - Ok::<_, TranslateError>(Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - })) - } - directive => Ok(directive), - }) - .collect::, _>>()?; - // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, - // make sure it gets propagated to `fn1` and `kernel` - let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( - methods_to_directly_used_shared_globals, - kernels_methods_call_map, - ); - // now visit every method declaration and inject those additional arguments - let mut directives = Vec::with_capacity(module.len()); - for directive in module.into_iter() { - match directive { - Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - }) => { - let statements = { - let func_decl_ref = &mut (*func_decl).borrow_mut(); - let method_name = func_decl_ref.name; - insert_arguments_remap_statements( - new_id, - kernels_methods_call_map, - &globals_shared, - &methods_to_indirectly_used_shared_globals, - method_name, - &mut directives, - func_decl_ref, - statements, - )? - }; - directives.push(Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - })); - } - directive => directives.push(directive), - } - } - Ok(directives) -} - -// We need to compute two kinds of information: -// * If it's a kernel -> size of .shared globals in use (direct or indirect) -// * If it's a function -> does it use .shared global (directly or indirectly) -fn resolve_indirect_uses_of_globals_shared<'input>( - methods_use_of_globals_shared: HashMap, HashSet>, - kernels_methods_call_map: &MethodsCallMap<'input>, -) -> HashMap, BTreeSet> { - let mut result = HashMap::new(); - for (method, callees) in kernels_methods_call_map.methods() { - let mut indirect_globals = methods_use_of_globals_shared - .get(&method) - .into_iter() - .flatten() - .copied() - .collect::>(); - for &callee in callees { - indirect_globals.extend( - methods_use_of_globals_shared - .get(&ast::MethodName::Func(callee)) - .into_iter() - .flatten() - .copied(), - ); - } - result.insert(method, indirect_globals); - } - result -} - -fn insert_arguments_remap_statements<'input>( - new_id: &mut impl FnMut() -> SpirvWord, - kernels_methods_call_map: &MethodsCallMap<'input>, - globals_shared: &HashMap, - methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, SpirvWord>, - BTreeSet, - >, - method_name: ast::MethodName, - result: &mut Vec, - func_decl_ref: &mut std::cell::RefMut>, - statements: Vec, SpirvWord>>, -) -> Result, SpirvWord>>, TranslateError> { - let remapped_globals_in_method = - if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { - match method_name { - ast::MethodName::Func(..) => { - let remapped_globals = method_globals - .iter() - .map(|global| { - ( - *global, - ( - new_id(), - globals_shared - .get(&global) - .unwrap_or_else(|| todo!()) - .clone(), - ), - ) - }) - .collect::>(); - for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { - func_decl_ref.input_arguments.push(ast::Variable { - align: None, - v_type: shared_global_type.clone(), - state_space: ast::StateSpace::Shared, - name: *new_shared_global_id, - array_init: Vec::new(), - }); - } - remapped_globals - } - ast::MethodName::Kernel(..) => method_globals - .iter() - .map(|global| { - ( - *global, - ( - *global, - globals_shared - .get(&global) - .unwrap_or_else(|| todo!()) - .clone(), - ), - ) - }) - .collect::>(), - } - } else { - return Ok(statements); - }; - replace_uses_of_shared_memory( - new_id, - methods_to_indirectly_used_shared_globals, - statements, - remapped_globals_in_method, - ) -} - -fn replace_uses_of_shared_memory<'input>( - new_id: &mut impl FnMut() -> SpirvWord, - methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, SpirvWord>, - BTreeSet, - >, - statements: Vec, - remapped_globals_in_method: BTreeMap, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - match statement { - Statement::Instruction(ast::Instruction::Call { - mut data, - mut arguments, - }) => { - // We can safely skip checking call arguments, - // because there's simply no way to pass shared ptr - // without converting it to .b64 first - if let Some(shared_globals_used_by_callee) = - methods_to_indirectly_used_shared_globals - .get(&ast::MethodName::Func(arguments.func)) - { - for &shared_global_used_by_callee in shared_globals_used_by_callee { - let (remapped_shared_id, type_) = remapped_globals_in_method - .get(&shared_global_used_by_callee) - .unwrap_or_else(|| todo!()); - data.input_arguments - .push((type_.clone(), ast::StateSpace::Shared)); - arguments.input_arguments.push(*remapped_shared_id); - } - } - result.push(Statement::Instruction(ast::Instruction::Call { - data, - arguments, - })) - } - statement => { - let new_statement = - statement.visit_map(&mut |id, - _: Option<(&ast::Type, ast::StateSpace)>, - _, - _| { - Ok::<_, TranslateError>( - if let Some((remapped_shared_id, _)) = - remapped_globals_in_method.get(&id) - { - *remapped_shared_id - } else { - id - }, - ) - })?; - result.push(new_statement); - } - } - } - Ok(result) -} diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs deleted file mode 100644 index 3b8fa93..0000000 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ /dev/null @@ -1,524 +0,0 @@ -use super::*; -use ptx_parser as ast; -use std::{ - collections::{BTreeSet, HashSet}, - iter, - rc::Rc, -}; - -/* - Our goal here is to transform - .visible .entry foobar(.param .u64 input) { - .reg .b64 in_addr; - .reg .b64 in_addr2; - ld.param.u64 in_addr, [input]; - cvta.to.global.u64 in_addr2, in_addr; - } - into: - .visible .entry foobar(.param .u8 input[]) { - .reg .u8 in_addr[]; - .reg .u8 in_addr2[]; - ld.param.u8[] in_addr, [input]; - mov.u8[] in_addr2, in_addr; - } - or: - .visible .entry foobar(.reg .u8 input[]) { - .reg .u8 in_addr[]; - .reg .u8 in_addr2[]; - mov.u8[] in_addr, input; - mov.u8[] in_addr2, in_addr; - } - or: - .visible .entry foobar(.param ptr input) { - .reg ptr in_addr; - .reg ptr in_addr2; - ld.param.ptr in_addr, [input]; - mov.ptr in_addr2, in_addr; - } -*/ -// TODO: detect more patterns (mov, call via reg, call via param) -// TODO: don't convert to ptr if the register is not ultimately used for ld/st -// TODO: once insert_mem_ssa_statements is moved to later, move this pass after -// argument expansion -// TODO: propagate out of calls and into calls -pub(super) fn run<'a, 'input>( - func_args: Rc>>, - func_body: Vec, - id_defs: &mut NumericIdResolver<'a>, -) -> Result< - ( - Rc>>, - Vec, - ), - TranslateError, -> { - let mut method_decl = func_args.borrow_mut(); - if !matches!(method_decl.name, ast::MethodName::Kernel(..)) { - drop(method_decl); - return Ok((func_args, func_body)); - } - if Rc::strong_count(&func_args) != 1 { - return Err(error_unreachable()); - } - let func_args_64bit = (*method_decl) - .input_arguments - .iter() - .filter_map(|arg| match arg.v_type { - ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), - _ => None, - }) - .collect::>(); - let mut stateful_markers = Vec::new(); - let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); - for statement in func_body.iter() { - match statement { - Statement::Instruction(ast::Instruction::Cvta { - data: - ast::CvtaDetails { - state_space: ast::StateSpace::Global, - direction: ast::CvtaDirection::GenericToExplicit, - }, - arguments, - }) => { - if let (TypedOperand::Reg(dst), Some(src)) = - (arguments.dst, arguments.src.underlying_register()) - { - if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) { - stateful_markers.push((dst, src)); - } - } - } - Statement::Instruction(ast::Instruction::Ld { - data: - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::U64), - .. - }, - arguments, - }) - | Statement::Instruction(ast::Instruction::Ld { - data: - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::S64), - .. - }, - arguments, - }) - | Statement::Instruction(ast::Instruction::Ld { - data: - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::B64), - .. - }, - arguments, - }) => { - if let (TypedOperand::Reg(dst), Some(src)) = - (arguments.dst, arguments.src.underlying_register()) - { - if func_args_64bit.contains(&src) { - multi_hash_map_append(&mut stateful_init_reg, dst, src); - } - } - } - _ => {} - } - } - if stateful_markers.len() == 0 { - drop(method_decl); - return Ok((func_args, func_body)); - } - let mut func_args_ptr = HashSet::new(); - let mut regs_ptr_current = HashSet::new(); - for (dst, src) in stateful_markers { - if let Some(func_args) = stateful_init_reg.get(&src) { - for a in func_args { - func_args_ptr.insert(*a); - regs_ptr_current.insert(src); - regs_ptr_current.insert(dst); - } - } - } - // BTreeSet here to have a stable order of iteration, - // unfortunately our tests rely on it - let mut regs_ptr_seen = BTreeSet::new(); - while regs_ptr_current.len() > 0 { - let mut regs_ptr_new = HashSet::new(); - for statement in func_body.iter() { - match statement { - Statement::Instruction(ast::Instruction::Add { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::U64, - saturate: false, - }), - arguments, - }) - | Statement::Instruction(ast::Instruction::Add { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::S64, - saturate: false, - }), - arguments, - }) => { - // TODO: don't mark result of double pointer sub or double - // pointer add as ptr result - if let (TypedOperand::Reg(dst), Some(src1)) = - (arguments.dst, arguments.src1.underlying_register()) - { - if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { - regs_ptr_new.insert(dst); - } - } else if let (TypedOperand::Reg(dst), Some(src2)) = - (arguments.dst, arguments.src2.underlying_register()) - { - if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { - regs_ptr_new.insert(dst); - } - } - } - - Statement::Instruction(ast::Instruction::Sub { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::U64, - saturate: false, - }), - arguments, - }) - | Statement::Instruction(ast::Instruction::Sub { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::S64, - saturate: false, - }), - arguments, - }) => { - // TODO: don't mark result of double pointer sub or double - // pointer add as ptr result - if let (TypedOperand::Reg(dst), Some(src1)) = - (arguments.dst, arguments.src1.underlying_register()) - { - if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { - regs_ptr_new.insert(dst); - } - } else if let (TypedOperand::Reg(dst), Some(src2)) = - (arguments.dst, arguments.src2.underlying_register()) - { - if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { - regs_ptr_new.insert(dst); - } - } - } - _ => {} - } - } - for id in regs_ptr_current { - regs_ptr_seen.insert(id); - } - regs_ptr_current = regs_ptr_new; - } - drop(regs_ptr_current); - let mut remapped_ids = HashMap::new(); - let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); - for reg in regs_ptr_seen { - let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Reg, - ); - result.push(Statement::Variable(ast::Variable { - align: None, - name: new_id, - array_init: Vec::new(), - v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - state_space: ast::StateSpace::Reg, - })); - remapped_ids.insert(reg, new_id); - } - for arg in (*method_decl).input_arguments.iter_mut() { - if !func_args_ptr.contains(&arg.name) { - continue; - } - let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Param, - ); - let old_name = arg.name; - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); - arg.name = new_id; - remapped_ids.insert(old_name, new_id); - } - for statement in func_body { - match statement { - l @ Statement::Label(_) => result.push(l), - c @ Statement::Conditional(_) => result.push(c), - c @ Statement::Constant(..) => result.push(c), - Statement::Variable(var) => { - if !remapped_ids.contains_key(&var.name) { - result.push(Statement::Variable(var)); - } - } - Statement::Instruction(ast::Instruction::Add { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::U64, - saturate: false, - }), - arguments, - }) - | Statement::Instruction(ast::Instruction::Add { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::S64, - saturate: false, - }), - arguments, - }) if is_add_ptr_direct(&remapped_ids, &arguments) => { - let (ptr, offset) = match arguments.src1.underlying_register() { - Some(src1) if remapped_ids.contains_key(&src1) => { - (remapped_ids.get(&src1).unwrap(), arguments.src2) - } - Some(src2) if remapped_ids.contains_key(&src2) => { - (remapped_ids.get(&src2).unwrap(), arguments.src1) - } - _ => return Err(error_unreachable()), - }; - let dst = arguments.dst.unwrap_reg()?; - result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, - dst: *remapped_ids.get(&dst).unwrap(), - ptr_src: *ptr, - offset_src: offset, - })) - } - Statement::Instruction(ast::Instruction::Sub { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::U64, - saturate: false, - }), - arguments, - }) - | Statement::Instruction(ast::Instruction::Sub { - data: - ast::ArithDetails::Integer(ast::ArithInteger { - type_: ast::ScalarType::S64, - saturate: false, - }), - arguments, - }) if is_sub_ptr_direct(&remapped_ids, &arguments) => { - let (ptr, offset) = match arguments.src1.underlying_register() { - Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2), - _ => return Err(error_unreachable()), - }; - let offset_neg = id_defs.register_intermediate(Some(( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ))); - result.push(Statement::Instruction(ast::Instruction::Neg { - data: ast::TypeFtz { - type_: ast::ScalarType::S64, - flush_to_zero: None, - }, - arguments: ast::NegArgs { - src: offset, - dst: TypedOperand::Reg(offset_neg), - }, - })); - let dst = arguments.dst.unwrap_reg()?; - result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, - dst: *remapped_ids.get(&dst).unwrap(), - ptr_src: *ptr, - offset_src: TypedOperand::Reg(offset_neg), - })) - } - inst @ Statement::Instruction(_) => { - let mut post_statements = Vec::new(); - let new_statement = inst.visit_map(&mut FnVisitor::new( - |operand, type_space, is_dst, relaxed_conversion| { - convert_to_stateful_memory_access_postprocess( - id_defs, - &remapped_ids, - &mut result, - &mut post_statements, - operand, - type_space, - is_dst, - relaxed_conversion, - ) - }, - ))?; - result.push(new_statement); - result.extend(post_statements); - } - repack @ Statement::RepackVector(_) => { - let mut post_statements = Vec::new(); - let new_statement = repack.visit_map(&mut FnVisitor::new( - |operand, type_space, is_dst, relaxed_conversion| { - convert_to_stateful_memory_access_postprocess( - id_defs, - &remapped_ids, - &mut result, - &mut post_statements, - operand, - type_space, - is_dst, - relaxed_conversion, - ) - }, - ))?; - result.push(new_statement); - result.extend(post_statements); - } - _ => return Err(error_unreachable()), - } - } - drop(method_decl); - Ok((func_args, result)) -} - -fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { - match id_defs.get_typed(id) { - Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) - | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) - | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, - _ => false, - } -} - -fn is_add_ptr_direct( - remapped_ids: &HashMap, - arg: &ast::AddArgs, -) -> bool { - match arg.dst { - TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { - return false - } - TypedOperand::Reg(dst) => { - if !remapped_ids.contains_key(&dst) { - return false; - } - if let Some(ref src1_reg) = arg.src1.underlying_register() { - if remapped_ids.contains_key(src1_reg) { - // don't trigger optimization when adding two pointers - if let Some(ref src2_reg) = arg.src2.underlying_register() { - return !remapped_ids.contains_key(src2_reg); - } - } - } - if let Some(ref src2_reg) = arg.src2.underlying_register() { - remapped_ids.contains_key(src2_reg) - } else { - false - } - } - } -} - -fn is_sub_ptr_direct( - remapped_ids: &HashMap, - arg: &ast::SubArgs, -) -> bool { - match arg.dst { - TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { - return false - } - TypedOperand::Reg(dst) => { - if !remapped_ids.contains_key(&dst) { - return false; - } - match arg.src1.underlying_register() { - Some(ref src1_reg) => { - if remapped_ids.contains_key(src1_reg) { - // don't trigger optimization when subtracting two pointers - arg.src2 - .underlying_register() - .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg)) - } else { - false - } - } - None => false, - } - } - } -} - -fn convert_to_stateful_memory_access_postprocess( - id_defs: &mut NumericIdResolver, - remapped_ids: &HashMap, - result: &mut Vec, - post_statements: &mut Vec, - operand: TypedOperand, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_conversion: bool, -) -> Result { - operand.map(|operand, _| { - Ok(match remapped_ids.get(&operand) { - Some(new_id) => { - let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; - // TODO: readd if required - if let Some((expected_type, expected_space)) = type_space { - let implicit_conversion = if relaxed_conversion { - if is_dst { - super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper - } else { - super::insert_implicit_conversions::should_convert_relaxed_src_wrapper - } - } else { - super::insert_implicit_conversions::default_implicit_conversion - }; - if implicit_conversion( - (new_operand_space, &new_operand_type), - (expected_space, expected_type), - ) - .is_ok() - { - return Ok(*new_id); - } - } - let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; - let converting_id = id_defs - .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if new_operand_space == ast::StateSpace::Reg { - ConversionKind::Default - } else { - ConversionKind::PtrToPtr - }; - if is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_operand_type, - from_space: old_operand_space, - to_type: new_operand_type, - to_space: new_operand_space, - kind, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: new_operand_type, - from_space: new_operand_space, - to_type: old_operand_type, - to_space: old_operand_space, - kind, - })); - converting_id - } - } - None => operand, - }) - }) -} diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs deleted file mode 100644 index 550c662..0000000 --- a/ptx/src/pass/convert_to_typed.rs +++ /dev/null @@ -1,138 +0,0 @@ -use super::*; -use ptx_parser as ast; - -pub(crate) fn run( - func: Vec, - fn_defs: &GlobalFnDeclResolver, - id_defs: &mut NumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::::with_capacity(func.len()); - for s in func { - match s { - Statement::Instruction(inst) => match inst { - ast::Instruction::Mov { - data, - arguments: - ast::MovArgs { - dst: ast::ParsedOperand::Reg(dst_reg), - src: ast::ParsedOperand::Reg(src_reg), - }, - } if fn_defs.fns.contains_key(&src_reg) => { - if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { - return Err(error_mismatched_type()); - } - result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { - dst: dst_reg, - src: src_reg, - })); - } - ast::Instruction::Call { data, arguments } => { - let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?; - let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?; - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let reresolved_call = - Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?); - visitor.func.push(reresolved_call); - visitor.func.extend(visitor.post_stmts); - } - inst => { - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?); - visitor.func.push(instruction); - visitor.func.extend(visitor.post_stmts); - } - }, - Statement::Label(i) => result.push(Statement::Label(i)), - Statement::Variable(v) => result.push(Statement::Variable(v)), - Statement::Conditional(c) => result.push(Statement::Conditional(c)), - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - -struct VectorRepackVisitor<'a, 'b> { - func: &'b mut Vec, - id_def: &'b mut NumericIdResolver<'a>, - post_stmts: Option, -} - -impl<'a, 'b> VectorRepackVisitor<'a, 'b> { - fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { - VectorRepackVisitor { - func, - id_def, - post_stmts: None, - } - } - - fn convert_vector( - &mut self, - is_dst: bool, - relaxed_type_check: bool, - typ: &ast::Type, - state_space: ast::StateSpace, - idx: Vec, - ) -> Result { - // mov.u32 foobar, {a,b}; - let scalar_t = match typ { - ast::Type::Vector(_, scalar_t) => *scalar_t, - _ => return Err(error_mismatched_type()), - }; - let temp_vec = self - .id_def - .register_intermediate(Some((typ.clone(), state_space))); - let statement = Statement::RepackVector(RepackVectorDetails { - is_extract: is_dst, - typ: scalar_t, - packed: temp_vec, - unpacked: idx, - relaxed_type_check, - }); - if is_dst { - self.post_stmts = Some(statement); - } else { - self.func.push(statement); - } - Ok(temp_vec) - } -} - -impl<'a, 'b> ast::VisitorMap, TypedOperand, TranslateError> - for VectorRepackVisitor<'a, 'b> -{ - fn visit_ident( - &mut self, - ident: SpirvWord, - _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - _: bool, - _: bool, - ) -> Result { - Ok(ident) - } - - fn visit( - &mut self, - op: ast::ParsedOperand, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - Ok(match op { - ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), - ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), - ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), - ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::ParsedOperand::VecPack(vec) => { - let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?; - TypedOperand::Reg(self.convert_vector( - is_dst, - relaxed_type_check, - type_, - space, - vec, - )?) - } - }) - } -} diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..15125b0 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap; - use super::*; pub(super) fn run<'a, 'input>( @@ -26,75 +24,73 @@ fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2, mut method: Function2<'input, ast::Instruction, SpirvWord>, ) -> Result, SpirvWord>, TranslateError> { - if method.func_decl.name.is_kernel() { - return Ok(method); - } let is_declaration = method.body.is_none(); let mut body = Vec::new(); let mut remap_returns = Vec::new(); - for arg in method.func_decl.return_arguments.iter_mut() { - match arg.state_space { - ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; - let old_name = arg.name; - arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); - if is_declaration { - continue; + if !method.func_decl.name.is_kernel() { + for arg in method.func_decl.return_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = + resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + remap_returns.push((old_name, arg.name, arg.v_type.clone())); + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); } - remap_returns.push((old_name, arg.name, arg.v_type.clone())); - body.push(Statement::Variable(ast::Variable { - align: None, - name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), - })); + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), } - ptx_parser::StateSpace::Reg => {} - _ => return Err(error_unreachable()), } - } - for arg in method.func_decl.input_arguments.iter_mut() { - match arg.state_space { - ptx_parser::StateSpace::Param => { - arg.state_space = ptx_parser::StateSpace::Reg; - let old_name = arg.name; - arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); - if is_declaration { - continue; + for arg in method.func_decl.input_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = + resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + body.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: arg.v_type.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: arg.name, + }, + })); } - body.push(Statement::Variable(ast::Variable { - align: None, - name: old_name, - v_type: arg.v_type.clone(), - state_space: ptx_parser::StateSpace::Param, - array_init: Vec::new(), - })); - body.push(Statement::Instruction(ast::Instruction::St { - data: ast::StData { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Param, - caching: ast::StCacheOperator::Writethrough, - typ: arg.v_type.clone(), - }, - arguments: ast::StArgs { - src1: old_name, - src2: arg.name, - }, - })); + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), } - ptx_parser::StateSpace::Reg => {} - _ => return Err(error_unreachable()), } } - if remap_returns.is_empty() { - return Ok(method); - } let body = method .body .map(|statements| { for statement in statements { - run_statement(&remap_returns, &mut body, statement)?; + run_statement(resolver, &remap_returns, &mut body, statement)?; } Ok::<_, TranslateError>(body) }) @@ -110,28 +106,89 @@ fn run_method<'input>( } fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, result: &mut Vec, SpirvWord>>, statement: Statement, SpirvWord>, ) -> Result<(), TranslateError> { match statement { - Statement::Instruction(ast::Instruction::Ret { .. }) => { - for (old_name, new_name, type_) in remap_returns.iter().cloned() { + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + let mut post_st = Vec::new(); + for ((type_, space), ident) in data + .input_arguments + .iter_mut() + .zip(arguments.input_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *ident, + src: old_name, + }, + })); + } + } + for ((type_, space), ident) in data + .return_arguments + .iter_mut() + .zip(arguments.return_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + post_st.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: *ident, + }, + })); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + result.extend(post_st.into_iter()); + } + Statement::Instruction(ast::Instruction::Ret { data }) => { + for (old_name, new_name, type_) in remap_returns.iter() { result.push(Statement::Instruction(ast::Instruction::Ld { data: ast::LdDetails { qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Reg, + state_space: ast::StateSpace::Param, caching: ast::LdCacheOperator::Cached, - typ: type_, + typ: type_.clone(), non_coherent: false, }, arguments: ast::LdArgs { - dst: new_name, - src: old_name, + dst: *new_name, + src: *old_name, }, })); } - result.push(statement); + result.push(Statement::Instruction(ast::Instruction::Ret { data })); } statement => { result.push(statement); diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 235ad7d..fa011a3 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -18,16 +18,23 @@ // while with plain LLVM-C it's just: // unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; -use std::convert::{TryFrom, TryInto}; -use std::ffi::CStr; +// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete. +// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with +// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all" +// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel", +// but it will too fail similarly, but with "unable to legalize instruction" + +use std::array::TryFromSliceError; +use std::convert::TryInto; +use std::ffi::{CStr, NulError}; use std::ops::Deref; -use std::ptr; +use std::{i8, ptr}; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; -use llvm_zluda::core::*; -use llvm_zluda::prelude::*; +use llvm_zluda::{core::*, *}; +use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; const LLVM_UNNAMED: &CStr = c""; @@ -172,7 +179,7 @@ pub(super) fn run<'input>( let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); for directive in directives { match directive { - Directive2::Variable(..) => todo!(), + Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, Directive2::Method(method) => emit_ctx.emit_method(method)?, } } @@ -228,15 +235,18 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }) .ok_or_else(|| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?; - let fn_type = get_function_type( - self.context, - func_decl.return_arguments.iter().map(|v| &v.v_type), - func_decl - .input_arguments - .iter() - .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), - )?; - let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + if fn_ == ptr::null_mut() { + let fn_type = get_function_type( + self.context, + func_decl.return_arguments.iter().map(|v| &v.v_type), + func_decl + .input_arguments + .iter() + .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), + )?; + fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + } if let ast::MethodName::Func(name) = func_decl.name { self.resolver.register(name, fn_); } @@ -274,6 +284,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); + for var in func_decl.return_arguments { + method_emitter.emit_variable(var)?; + } + for statement in statements.iter() { + if let Statement::Label(label) = statement { + method_emitter.emit_label_initial(*label); + } + } for statement in statements { method_emitter.emit_statement(statement)?; } @@ -281,43 +299,146 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } Ok(()) } + + fn emit_global( + &mut self, + _linking: ast::LinkingDirective, + var: ast::Variable, + ) -> Result<(), TranslateError> { + let name = self + .id_defs + .ident_map + .get(&var.name) + .map(|entry| { + entry + .name + .as_ref() + .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?))) + }) + .flatten() + .transpose() + .map_err(|_| error_unreachable())? + .unwrap_or(Cow::Borrowed(LLVM_UNNAMED)); + let global = unsafe { + LLVMAddGlobalInAddressSpace( + self.module, + get_type(self.context, &var.v_type)?, + name.as_ptr(), + get_state_space(var.state_space)?, + ) + }; + self.resolver.register(var.name, global); + if let Some(align) = var.align { + unsafe { LLVMSetAlignment(global, align) }; + } + if !var.array_init.is_empty() { + self.emit_array_init(&var.v_type, &*var.array_init, global)?; + } + Ok(()) + } + + // TODO: instead of Vec we should emit a typed initializer + fn emit_array_init( + &mut self, + type_: &ast::Type, + array_init: &[u8], + global: *mut llvm_zluda::LLVMValue, + ) -> Result<(), TranslateError> { + match type_ { + ast::Type::Array(None, scalar, dimensions) => { + if dimensions.len() != 1 { + todo!() + } + if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { + return Err(error_unreachable()); + } + let type_ = get_scalar_type(self.context, *scalar); + let mut elements = array_init + .chunks(scalar.size_of() as usize) + .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) + .collect::, _>>() + .map_err(|_| error_unreachable())?; + let initializer = + unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; + unsafe { LLVMSetInitializer(global, initializer) }; + } + _ => todo!(), + } + Ok(()) + } + + fn constant_from_bytes( + &self, + scalar: ast::ScalarType, + bytes: &[u8], + llvm_type: LLVMTypeRef, + ) -> Result { + Ok(match scalar { + ptx_parser::ScalarType::Pred + | ptx_parser::ScalarType::S8 + | ptx_parser::ScalarType::B8 + | ptx_parser::ScalarType::U8 => unsafe { + LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::S16 + | ptx_parser::ScalarType::B16 + | ptx_parser::ScalarType::U16 => unsafe { + LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::S32 + | ptx_parser::ScalarType::B32 + | ptx_parser::ScalarType::U32 => unsafe { + LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0) + }, + ptx_parser::ScalarType::F16 => todo!(), + ptx_parser::ScalarType::BF16 => todo!(), + ptx_parser::ScalarType::U64 => todo!(), + ptx_parser::ScalarType::S64 => todo!(), + ptx_parser::ScalarType::S16x2 => todo!(), + ptx_parser::ScalarType::F32 => todo!(), + ptx_parser::ScalarType::B64 => todo!(), + ptx_parser::ScalarType::F64 => todo!(), + ptx_parser::ScalarType::B128 => todo!(), + ptx_parser::ScalarType::U16x2 => todo!(), + ptx_parser::ScalarType::F16x2 => todo!(), + ptx_parser::ScalarType::BF16x2 => todo!(), + }) + } } fn get_input_argument_type( context: LLVMContextRef, - v_type: &ptx_parser::Type, - state_space: ptx_parser::StateSpace, + v_type: &ast::Type, + state_space: ast::StateSpace, ) -> Result { match state_space { - ptx_parser::StateSpace::ParamEntry => { + ast::StateSpace::ParamEntry => { Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) }) } - ptx_parser::StateSpace::Reg => get_type(context, v_type), + ast::StateSpace::Reg => get_type(context, v_type), _ => return Err(error_unreachable()), } } -struct MethodEmitContext<'a, 'input> { +struct MethodEmitContext<'a> { context: LLVMContextRef, module: LLVMModuleRef, method: LLVMValueRef, builder: LLVMBuilderRef, - id_defs: &'a GlobalStringIdentResolver2<'input>, variables_builder: Builder, resolver: &'a mut ResolveIdent, } -impl<'a, 'input> MethodEmitContext<'a, 'input> { - fn new<'x>( - parent: &'a mut ModuleEmitContext<'x, 'input>, +impl<'a> MethodEmitContext<'a> { + fn new( + parent: &'a mut ModuleEmitContext, method: LLVMValueRef, variables_builder: Builder, - ) -> MethodEmitContext<'a, 'input> { + ) -> MethodEmitContext<'a> { MethodEmitContext { context: parent.context, module: parent.module, builder: parent.builder.get(), - id_defs: parent.id_defs, variables_builder, resolver: &mut parent.resolver, method, @@ -330,18 +451,17 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ) -> Result<(), TranslateError> { Ok(match statement { Statement::Variable(var) => self.emit_variable(var)?, - Statement::Label(label) => self.emit_label(label), + Statement::Label(label) => self.emit_label_delayed(label)?, Statement::Instruction(inst) => self.emit_instruction(inst)?, - Statement::Conditional(_) => todo!(), - Statement::LoadVar(var) => self.emit_load_variable(var)?, - Statement::StoreVar(store) => self.emit_store_var(store)?, + Statement::Conditional(cond) => self.emit_conditional(cond)?, Statement::Conversion(conversion) => self.emit_conversion(conversion)?, Statement::Constant(constant) => self.emit_constant(constant)?, - Statement::RetValue(_, _) => todo!(), - Statement::PtrAccess(_) => todo!(), - Statement::RepackVector(_) => todo!(), + Statement::RetValue(_, values) => self.emit_ret_value(values)?, + Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, + Statement::RepackVector(repack) => self.emit_vector_repack(repack)?, Statement::FunctionPointer(_) => todo!(), - Statement::VectorAccess(_) => todo!(), + Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?, + Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?, }) } @@ -364,7 +484,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_label(&mut self, label: SpirvWord) { + fn emit_label_initial(&mut self, label: SpirvWord) { let block = unsafe { LLVMAppendBasicBlockInContext( self.context, @@ -372,17 +492,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { self.resolver.get_or_add_raw(label), ) }; + self.resolver + .register(label, unsafe { LLVMBasicBlockAsValue(block) }); + } + + fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> { + let block = self.resolver.value(label)?; + let block = unsafe { LLVMValueAsBasicBlock(block) }; let last_block = unsafe { LLVMGetInsertBlock(self.builder) }; if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() { unsafe { LLVMBuildBr(self.builder, block) }; } unsafe { LLVMPositionBuilderAtEnd(self.builder, block) }; - } - - fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> { - let ptr = self.resolver.value(store.arg.src1)?; - let value = self.resolver.value(store.arg.src2)?; - unsafe { LLVMBuildStore(self.builder, value, ptr) }; Ok(()) } @@ -395,50 +516,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), - ast::Instruction::Mul { data, arguments } => todo!(), - ast::Instruction::Setp { data, arguments } => todo!(), - ast::Instruction::SetpBool { data, arguments } => todo!(), - ast::Instruction::Not { data, arguments } => todo!(), - ast::Instruction::Or { data, arguments } => todo!(), - ast::Instruction::And { data, arguments } => todo!(), - ast::Instruction::Bra { arguments } => todo!(), + ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments), + ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments), + ast::Instruction::SetpBool { .. } => todo!(), + ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments), + ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments), + ast::Instruction::And { arguments, .. } => self.emit_and(arguments), + ast::Instruction::Bra { arguments } => self.emit_bra(arguments), ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), - ast::Instruction::Cvt { data, arguments } => todo!(), - ast::Instruction::Shr { data, arguments } => todo!(), - ast::Instruction::Shl { data, arguments } => todo!(), + ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments), + ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments), + ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments), ast::Instruction::Ret { data } => Ok(self.emit_ret(data)), - ast::Instruction::Cvta { data, arguments } => todo!(), - ast::Instruction::Abs { data, arguments } => todo!(), - ast::Instruction::Mad { data, arguments } => todo!(), - ast::Instruction::Fma { data, arguments } => todo!(), - ast::Instruction::Sub { data, arguments } => todo!(), - ast::Instruction::Min { data, arguments } => todo!(), - ast::Instruction::Max { data, arguments } => todo!(), - ast::Instruction::Rcp { data, arguments } => todo!(), - ast::Instruction::Sqrt { data, arguments } => todo!(), - ast::Instruction::Rsqrt { data, arguments } => todo!(), - ast::Instruction::Selp { data, arguments } => todo!(), - ast::Instruction::Bar { data, arguments } => todo!(), - ast::Instruction::Atom { data, arguments } => todo!(), - ast::Instruction::AtomCas { data, arguments } => todo!(), - ast::Instruction::Div { data, arguments } => todo!(), - ast::Instruction::Neg { data, arguments } => todo!(), - ast::Instruction::Sin { data, arguments } => todo!(), - ast::Instruction::Cos { data, arguments } => todo!(), - ast::Instruction::Lg2 { data, arguments } => todo!(), - ast::Instruction::Ex2 { data, arguments } => todo!(), - ast::Instruction::Clz { data, arguments } => todo!(), - ast::Instruction::Brev { data, arguments } => todo!(), - ast::Instruction::Popc { data, arguments } => todo!(), - ast::Instruction::Xor { data, arguments } => todo!(), - ast::Instruction::Rem { data, arguments } => todo!(), - ast::Instruction::Bfe { data, arguments } => todo!(), - ast::Instruction::Bfi { data, arguments } => todo!(), - ast::Instruction::PrmtSlow { arguments } => todo!(), - ast::Instruction::Prmt { data, arguments } => todo!(), - ast::Instruction::Activemask { arguments } => todo!(), - ast::Instruction::Membar { data } => todo!(), + ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments), + ast::Instruction::Abs { .. } => todo!(), + ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments), + ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments), + ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments), + ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments), + ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments), + ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments), + ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments), + ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments), + ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments), + ast::Instruction::Bar { .. } => todo!(), + ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments), + ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments), + ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments), + ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments), + ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments), + ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments), + ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments), + ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments), + ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments), + ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments), + ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments), + ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments), + ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments), + ast::Instruction::PrmtSlow { .. } => todo!(), + ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments), + ast::Instruction::Membar { data } => self.emit_membar(data), ast::Instruction::Trap {} => todo!(), + // replaced by a function call + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. } => return Err(error_unreachable()), } } @@ -447,9 +569,6 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { data: ast::LdDetails, arguments: ast::LdArgs, ) -> Result<(), TranslateError> { - if data.non_coherent { - todo!() - } if data.qualifier != ast::LdStQualifier::Weak { todo!() } @@ -462,24 +581,25 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> { - if var.member_index.is_some() { - todo!() - } - let builder = self.builder; - let type_ = get_type(self.context, &var.typ)?; - let ptr = self.resolver.value(var.arg.src)?; - self.resolver.with_result(var.arg.dst, |dst| unsafe { - LLVMBuildLoad2(builder, type_, ptr, dst) - }); - Ok(()) - } - fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> { let builder = self.builder; match conversion.kind { - ConversionKind::Default => todo!(), - ConversionKind::SignExtend => todo!(), + ConversionKind::Default => self.emit_conversion_default( + self.resolver.value(conversion.src)?, + conversion.dst, + &conversion.from_type, + conversion.from_space, + &conversion.to_type, + conversion.to_space, + ), + ConversionKind::SignExtend => { + let src = self.resolver.value(conversion.src)?; + let type_ = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildSExt(builder, src, type_, dst) + }); + Ok(()) + } ConversionKind::BitToPtr => { let src = self.resolver.value(conversion.src)?; let type_ = get_pointer_type(self.context, conversion.to_space)?; @@ -488,8 +608,131 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }); Ok(()) } - ConversionKind::PtrToPtr => todo!(), - ConversionKind::AddressOf => todo!(), + ConversionKind::PtrToPtr => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_pointer_type(self.context, conversion.to_space)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(builder, src, dst_type, dst) + }); + Ok(()) + } + ConversionKind::AddressOf => { + let src = self.resolver.value(conversion.src)?; + let dst_type = get_type(self.context, &conversion.to_type)?; + self.resolver.with_result(conversion.dst, |dst| unsafe { + LLVMBuildPtrToInt(self.builder, src, dst_type, dst) + }); + Ok(()) + } + } + } + + fn emit_conversion_default( + &mut self, + src: LLVMValueRef, + dst: SpirvWord, + from_type: &ast::Type, + from_space: ast::StateSpace, + to_type: &ast::Type, + to_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + match (from_type, to_type) { + (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => { + let from_layout = from_type.layout(); + let to_layout = to_type.layout(); + if from_layout.size() == to_layout.size() { + let dst_type = get_type(self.context, &to_type)?; + if from_type.kind() != ast::ScalarKind::Float + && to_type_scalar.kind() != ast::ScalarKind::Float + { + // It is noop, but another instruction expects result of this conversion + self.resolver.register(dst, src); + } else { + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + } + Ok(()) + } else { + // This block is safe because it's illegal to implictly convert between floating point values + let same_width_bit_type = unsafe { + LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32) + }; + let same_width_bit_value = unsafe { + LLVMBuildBitCast( + self.builder, + src, + same_width_bit_type, + LLVM_UNNAMED.as_ptr(), + ) + }; + let wide_bit_type = match to_type_scalar.layout().size() { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => return Err(error_unreachable()), + }; + let wide_bit_type_llvm = unsafe { + LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32) + }; + if to_type_scalar.kind() == ast::ScalarKind::Unsigned + || to_type_scalar.kind() == ast::ScalarKind::Bit + { + let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + self.resolver.with_result(dst, |dst| unsafe { + llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst) + }); + Ok(()) + } else { + let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed + && to_type_scalar.kind() == ast::ScalarKind::Signed + { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + } + } else { + if to_type_scalar.size_of() >= from_type.size_of() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + } + }; + let wide_bit_value = unsafe { + conversion_fn( + self.builder, + same_width_bit_value, + wide_bit_type_llvm, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.emit_conversion_default( + wide_bit_value, + dst, + &wide_bit_type.into(), + from_space, + to_type, + to_space, + ) + } + } + } + (ast::Type::Vector(..), ast::Type::Scalar(..)) + | (ast::Type::Scalar(..), ast::Type::Array(..)) + | (ast::Type::Array(..), ast::Type::Scalar(..)) => { + let dst_type = get_type(self.context, to_type)?; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildBitCast(self.builder, src, dst_type, dst) + }); + Ok(()) + } + _ => todo!(), } } @@ -514,8 +757,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; let fn_ = match data { - ast::ArithDetails::Integer(integer) => LLVMBuildAdd, - ast::ArithDetails::Float(float) => LLVMBuildFAdd, + ast::ArithDetails::Integer(..) => LLVMBuildAdd, + ast::ArithDetails::Float(..) => LLVMBuildFAdd, }; self.resolver.with_result(arguments.dst, |dst| unsafe { fn_(builder, src1, src2, dst) @@ -525,8 +768,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_st( &self, - data: ptx_parser::StData, - arguments: ptx_parser::StArgs, + data: ast::StData, + arguments: ast::StArgs, ) -> Result<(), TranslateError> { let ptr = self.resolver.value(arguments.src1)?; let value = self.resolver.value(arguments.src2)?; @@ -537,14 +780,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(()) } - fn emit_ret(&self, _data: ptx_parser::RetData) { + fn emit_ret(&self, _data: ast::RetData) { unsafe { LLVMBuildRetVoid(self.builder) }; } fn emit_call( &mut self, - data: ptx_parser::CallDetails, - arguments: ptx_parser::CallArgs, + data: ast::CallDetails, + arguments: ast::CallArgs, ) -> Result<(), TranslateError> { if cfg!(debug_assertions) { for (_, space) in data.return_arguments.iter() { @@ -558,14 +801,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { } } } - let name = match (&*data.return_arguments, &*arguments.return_arguments) { - ([], []) => LLVM_UNNAMED.as_ptr(), - ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst), + let name = match &*arguments.return_arguments { + [] => LLVM_UNNAMED.as_ptr(), + [dst] => self.resolver.get_or_add_raw(*dst), _ => todo!(), }; let type_ = get_function_type( self.context, - data.return_arguments.iter().map(|(type_, space)| type_), + data.return_arguments.iter().map(|(type_, ..)| type_), data.input_arguments .iter() .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)), @@ -597,148 +840,1553 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_mov( &mut self, - _data: ptx_parser::MovDetails, - arguments: ptx_parser::MovArgs, + _data: ast::MovDetails, + arguments: ast::MovArgs, ) -> Result<(), TranslateError> { self.resolver .register(arguments.dst, self.resolver.value(arguments.src)?); Ok(()) } -} -fn get_pointer_type<'ctx>( - context: LLVMContextRef, - to_space: ast::StateSpace, -) -> Result { - Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) }) -} + fn emit_ptr_access(&mut self, ptr_access: PtrAccess) -> Result<(), TranslateError> { + let ptr_src = self.resolver.value(ptr_access.ptr_src)?; + let mut offset_src = self.resolver.value(ptr_access.offset_src)?; + let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8); + self.resolver.with_result(ptr_access.dst, |dst| unsafe { + LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst) + }); + Ok(()) + } -fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { - Ok(match type_ { - ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), - ast::Type::Vector(size, scalar) => { - let base_type = get_scalar_type(context, *scalar); - unsafe { LLVMVectorType(base_type, *size as u32) } - } - ast::Type::Array(vec, scalar, dimensions) => { - let mut underlying_type = get_scalar_type(context, *scalar); - if let Some(size) = vec { - underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) }; + fn emit_and(&mut self, arguments: ast::AndArgs) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAnd(builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_atom( + &mut self, + data: ast::AtomDetails, + arguments: ast::AtomArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let op = match data.op { + ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd, + ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr, + ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor, + ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg, + ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd, + ast::AtomicOp::IncrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap } - if dimensions.is_empty() { - return Ok(unsafe { LLVMArrayType2(underlying_type, 0) }); + ast::AtomicOp::DecrementWrap => { + LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap } - dimensions - .iter() - .rfold(underlying_type, |result, dimension| unsafe { - LLVMArrayType2(result, *dimension as u64) - }) - } - ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?, - }) -} - -fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef { - match type_ { - ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) }, - ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { - LLVMInt8TypeInContext(context) - }, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { - LLVMInt16TypeInContext(context) - }, - ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { - LLVMInt32TypeInContext(context) - }, - ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe { - LLVMInt64TypeInContext(context) - }, - ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) }, - ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) }, - ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, - ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, - ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), - ast::ScalarType::BF16x2 => todo!(), + ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin, + ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin, + ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax, + ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax, + ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd, + ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin, + ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax, + }; + self.resolver.register(arguments.dst, unsafe { + LLVMZludaBuildAtomicRMW( + builder, + op, + src1, + src2, + get_scope(data.scope)?, + get_ordering(data.semantics), + ) + }); + Ok(()) } -} - -fn get_function_type<'a>( - context: LLVMContextRef, - mut return_args: impl ExactSizeIterator, - input_args: impl ExactSizeIterator>, -) -> Result { - let mut input_args: Vec<*mut llvm_zluda::LLVMType> = - input_args.collect::, _>>()?; - let return_type = match return_args.len() { - 0 => unsafe { LLVMVoidTypeInContext(context) }, - 1 => get_type(context, return_args.next().unwrap())?, - _ => todo!(), - }; - Ok(unsafe { - LLVMFunctionType( - return_type, - input_args.as_mut_ptr(), - input_args.len() as u32, - 0, - ) - }) -} -fn get_state_space(space: ast::StateSpace) -> Result { - match space { - ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), - ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(TranslateError::Todo), - ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::ParamFunc => Err(TranslateError::Todo), - ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), - ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), - ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), - ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), - ast::StateSpace::SharedCta => Err(TranslateError::Todo), - ast::StateSpace::SharedCluster => Err(TranslateError::Todo), + fn emit_atom_cas( + &mut self, + data: ast::AtomCasDetails, + arguments: ast::AtomCasArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + let success_ordering = get_ordering(data.semantics); + let failure_ordering = get_ordering_failure(data.semantics); + let temp = unsafe { + LLVMZludaBuildAtomicCmpXchg( + self.builder, + src1, + src2, + src3, + get_scope(data.scope)?, + success_ordering, + failure_ordering, + ) + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildExtractValue(self.builder, temp, 0, dst) + }); + Ok(()) } -} -struct ResolveIdent { - words: HashMap, - values: HashMap, -} + fn emit_bra(&self, arguments: ast::BraArgs) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let src = unsafe { LLVMValueAsBasicBlock(src) }; + unsafe { LLVMBuildBr(self.builder, src) }; + Ok(()) + } -impl ResolveIdent { - fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { - ResolveIdent { - words: HashMap::new(), - values: HashMap::new(), + fn emit_brev( + &mut self, + data: ast::ScalarType, + arguments: ast::BrevArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.bitreverse.i32", + 8 => c"llvm.bitreverse.i64", + _ => return Err(error_unreachable()), + }; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + let type_ = get_scalar_type(self.context, data); + let fn_type = get_function_type( + self.context, + iter::once(&data.into()), + iter::once(Ok(type_)), + )?; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; } + let mut src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst) + }); + Ok(()) } - fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T { - let str = match self.words.entry(word) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - let mut text = word.0.to_string(); - text.push('\0'); - entry.insert(text) + fn emit_ret_value( + &mut self, + values: Vec<(SpirvWord, ptx_parser::Type)>, + ) -> Result<(), TranslateError> { + match &*values { + [] => unsafe { LLVMBuildRetVoid(self.builder) }, + [(value, type_)] => { + let value = self.resolver.value(*value)?; + let type_ = get_type(self.context, type_)?; + let value = + unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) }; + unsafe { LLVMBuildRet(self.builder, value) } } + _ => todo!(), }; - fn_(&str[..str.len() - 1]) + Ok(()) } - fn get_or_add(&mut self, word: SpirvWord) -> &str { - self.get_or_ad_impl(word, |x| x) + fn emit_clz( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::ClzArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.size_of() { + 4 => c"llvm.ctlz.i32", + 8 => c"llvm.ctlz.i64", + _ => return Err(error_unreachable()), + }; + let type_ = get_scalar_type(self.context, data.into()); + let pred = get_scalar_type(self.context, ast::ScalarType::Pred); + let fn_type = get_function_type( + self.context, + iter::once(&ast::ScalarType::U32.into()), + [Ok(type_), Ok(pred)].into_iter(), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) }; + } + let src = self.resolver.value(arguments.src)?; + let false_ = unsafe { LLVMConstInt(pred, 0, 0) }; + let mut args = [src, false_]; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + args.as_mut_ptr(), + args.len() as u32, + dst, + ) + }); + Ok(()) } - fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 { - self.get_or_add(word).as_ptr().cast() + fn emit_mul( + &mut self, + data: ast::MulDetails, + arguments: ast::MulArgs, + ) -> Result<(), TranslateError> { + self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?; + Ok(()) } - fn register(&mut self, word: SpirvWord, v: LLVMValueRef) { - self.values.insert(word, v); - } + fn emit_mul_impl( + &mut self, + data: ast::MulDetails, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let mul_fn = match data { + ast::MulDetails::Integer { control, type_ } => match control { + ast::MulIntControl::Low => LLVMBuildMul, + ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2), + ast::MulIntControl::Wide => { + return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1) + } + }, + ast::MulDetails::Float(..) => LLVMBuildFMul, + }; + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + Ok(self + .resolver + .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) })) + } + + fn emit_mul_high( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result { + let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?; + let shift_constant = + unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) }; + let shifted = unsafe { + LLVMBuildLShr( + self.builder, + wide_value, + shift_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let narrow_type = get_scalar_type(self.context, type_); + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildTrunc(self.builder, shifted, narrow_type, dst) + })) + } + + fn emit_mul_wide_impl( + &mut self, + type_: ptx_parser::ScalarType, + dst: Option, + src1: SpirvWord, + src2: SpirvWord, + ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> { + let src1 = self.resolver.value(src1)?; + let src2 = self.resolver.value(src2)?; + let wide_type = + unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) }; + let llvm_cast = match type_.kind() { + ptx_parser::ScalarKind::Signed => LLVMBuildSExt, + ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt, + _ => return Err(error_unreachable()), + }; + let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) }; + Ok(( + wide_type, + self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildMul(self.builder, src1, src2, dst) + }), + )) + } + + fn emit_cos( + &mut self, + _data: ast::FlushToZero, + arguments: ast::CosArgs, + ) -> Result<(), TranslateError> { + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let cos = self.emit_intrinsic( + c"llvm.cos.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![(self.resolver.value(arguments.src)?, llvm_f32)], + )?; + unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_or( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::OrArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildOr(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_xor( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::XorArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildXor(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> { + let src = self.resolver.value(vec_acccess.vector_src)?; + let index = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vec_acccess.member as _, + 0, + ) + }; + self.resolver + .with_result(vec_acccess.scalar_dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + Ok(()) + } + + fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> { + let vector_src = self.resolver.value(vector_write.vector_src)?; + let scalar_src = self.resolver.value(vector_write.scalar_src)?; + let index = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B8), + vector_write.member as _, + 0, + ) + }; + self.resolver + .with_result(vector_write.vector_dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst) + }); + Ok(()) + } + + fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> { + let i8_type = get_scalar_type(self.context, ast::ScalarType::B8); + if repack.is_extract { + let src = self.resolver.value(repack.packed)?; + for (index, dst) in repack.unpacked.iter().enumerate() { + let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + self.resolver.with_result(*dst, |dst| unsafe { + LLVMBuildExtractElement(self.builder, src, index, dst) + }); + } + } else { + let vector_type = get_type( + self.context, + &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ), + )?; + let mut temp_vec = unsafe { LLVMGetUndef(vector_type) }; + for (index, src_id) in repack.unpacked.iter().enumerate() { + let dst = if index == repack.unpacked.len() - 1 { + Some(repack.packed) + } else { + None + }; + let scalar_src = self.resolver.value(*src_id)?; + let index = unsafe { LLVMConstInt(i8_type, index as _, 0) }; + temp_vec = self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst) + }); + } + } + Ok(()) + } + + fn emit_div( + &mut self, + data: ptx_parser::DivDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let integer_div = match data { + ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv, + ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv, + ptx_parser::DivDetails::Float(float_div) => { + return self.emit_div_float(float_div, arguments) + } + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + integer_div(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_div_float( + &mut self, + float_div: ptx_parser::DivFloatDetails, + arguments: ptx_parser::DivArgs, + ) -> Result<(), TranslateError> { + let builder = self.builder; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let _rnd = match float_div.kind { + ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven, + ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode, + }; + let approx = match float_div.kind { + ptx_parser::DivFloatKind::Approx => { + LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc + } + ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone, + ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone, + }; + let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(builder, src1, src2, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) }; + if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div: + // div.full.f32 implements a relatively fast, full-range approximation that scales + // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not + // support rounding modifiers. The maximum ulp error is 2 across the full range of + // inputs. + // https://llvm.org/docs/LangRef.html#fpmath-metadata + let fpmath_value = + unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) }; + let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) }; + let mut md_node_content = [fpmath_value]; + let md_node = unsafe { + LLVMMDNodeInContext2( + self.context, + md_node_content.as_mut_ptr(), + md_node_content.len(), + ) + }; + let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; + let kind = unsafe { + LLVMGetMDKindIDInContext( + self.context, + "fpmath".as_ptr().cast(), + "fpmath".len() as u32, + ) + }; + unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; + } + Ok(()) + } + + fn emit_cvta( + &mut self, + data: ptx_parser::CvtaDetails, + arguments: ptx_parser::CvtaArgs, + ) -> Result<(), TranslateError> { + let (from_space, to_space) = match data.direction { + ptx_parser::CvtaDirection::GenericToExplicit => { + (ast::StateSpace::Generic, data.state_space) + } + ptx_parser::CvtaDirection::ExplicitToGeneric => { + (data.state_space, ast::StateSpace::Generic) + } + }; + let from_type = get_pointer_type(self.context, from_space)?; + let dest_type = get_pointer_type(self.context, to_space)?; + let src = self.resolver.value(arguments.src)?; + let temp_ptr = + unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) + }); + Ok(()) + } + + fn emit_sub( + &mut self, + data: ptx_parser::ArithDetails, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + match data { + ptx_parser::ArithDetails::Integer(arith_integer) => { + self.emit_sub_integer(arith_integer, arguments) + } + ptx_parser::ArithDetails::Float(arith_float) => { + self.emit_sub_float(arith_float, arguments) + } + } + } + + fn emit_sub_integer( + &mut self, + arith_integer: ptx_parser::ArithInteger, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + if arith_integer.saturate { + todo!() + } + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sub_float( + &mut self, + arith_float: ptx_parser::ArithFloat, + arguments: ptx_parser::SubArgs, + ) -> Result<(), TranslateError> { + if arith_float.saturate { + todo!() + } + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFSub(self.builder, src1, src2, dst) + }); + Ok(()) + } + + fn emit_sin( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::SinArgs, + ) -> Result<(), TranslateError> { + let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); + let sin = self.emit_intrinsic( + c"llvm.sin.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![(self.resolver.value(arguments.src)?, llvm_f32)], + )?; + unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } + Ok(()) + } + + fn emit_intrinsic( + &mut self, + name: &CStr, + dst: Option, + return_type: &ast::Type, + arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, + ) -> Result { + let fn_type = get_function_type( + self.context, + iter::once(return_type), + arguments.iter().map(|(_, type_)| Ok(*type_)), + )?; + let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; + if fn_ == ptr::null_mut() { + fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + } + let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::>(); + Ok(self.resolver.with_result_option(dst, |dst| unsafe { + LLVMBuildCall2( + self.builder, + fn_type, + fn_, + arguments.as_mut_ptr(), + arguments.len() as u32, + dst, + ) + })) + } + + fn emit_neg( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::NegArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float { + LLVMBuildFNeg + } else { + LLVMBuildNeg + }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst) + }); + Ok(()) + } + + fn emit_not( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::NotArgs, + ) -> Result<(), TranslateError> { + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildNot(self.builder, src, dst) + }); + Ok(()) + } + + fn emit_setp( + &mut self, + data: ptx_parser::SetpData, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + if arguments.dst2.is_some() { + todo!() + } + match data.cmp_op { + ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { + self.emit_setp_int(setp_compare_int, arguments) + } + ptx_parser::SetpCompareOp::Float(setp_compare_float) => { + self.emit_setp_float(setp_compare_float, arguments) + } + } + } + + fn emit_setp_int( + &mut self, + setp: ptx_parser::SetpCompareInt, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + let op = match setp { + ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, + ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, + ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT, + ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE, + ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT, + ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE, + ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT, + ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE, + ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, + ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst1, |dst1| unsafe { + LLVMBuildICmp(self.builder, op, src1, src2, dst1) + }); + Ok(()) + } + + fn emit_setp_float( + &mut self, + setp: ptx_parser::SetpCompareFloat, + arguments: ptx_parser::SetpArgs, + ) -> Result<(), TranslateError> { + let op = match setp { + ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, + ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, + ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT, + ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE, + ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT, + ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE, + ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ, + ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE, + ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT, + ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE, + ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT, + ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE, + ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, + ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst1, |dst1| unsafe { + LLVMBuildFCmp(self.builder, op, src1, src2, dst1) + }); + Ok(()) + } + + fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { + let predicate = self.resolver.value(cond.predicate)?; + let if_true = self.resolver.value(cond.if_true)?; + let if_false = self.resolver.value(cond.if_false)?; + unsafe { + LLVMBuildCondBr( + self.builder, + predicate, + LLVMValueAsBasicBlock(if_true), + LLVMValueAsBasicBlock(if_false), + ) + }; + Ok(()) + } + + fn emit_cvt( + &mut self, + data: ptx_parser::CvtDetails, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let dst_type = get_scalar_type(self.context, data.to); + let llvm_fn = match data.mode { + ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, + ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, + ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, + ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, + ptx_parser::CvtMode::SaturateUnsignedToSigned => { + return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) + } + ptx_parser::CvtMode::SaturateSignedToUnsigned => { + return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) + } + ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt, + ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc, + ptx_parser::CvtMode::FPRound { + integer_rounding, .. + } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + integer_rounding.unwrap_or(ast::RoundingMode::NearestEven), + arguments, + Some(LLVMBuildFPToSI), + ) + } + ptx_parser::CvtMode::SignedFromFP { rounding, .. } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + Some(LLVMBuildFPToSI), + ) + } + ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + rounding, + arguments, + Some(LLVMBuildFPToUI), + ) + } + ptx_parser::CvtMode::FPFromSigned(_) => todo!(), + ptx_parser::CvtMode::FPFromUnsigned(_) => todo!(), + }; + let src = self.resolver.value(arguments.src)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }); + Ok(()) + } + + fn emit_cvt_unsigned_to_signed_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + // This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1, + // so if it's downcast to a smaller type, it will be the maximum value + // of the smaller type + let max_value = match to { + ptx_parser::ScalarType::S8 => i8::MAX as u64, + ptx_parser::ScalarType::S16 => i16::MAX as u64, + ptx_parser::ScalarType::S32 => i32::MAX as u64, + ptx_parser::ScalarType::S64 => i64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let from_llvm = get_scalar_type(self.context, from); + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let clamped = self.emit_intrinsic( + c"llvm.umin", + None, + &from.into(), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (max, from_llvm), + ], + )?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, clamped, to_llvm, dst) + }); + Ok(()) + } + + fn emit_cvt_signed_to_unsigned_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let from_llvm = get_scalar_type(self.context, from); + let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) }; + let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); + let zero_clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) }, + None, + &from.into(), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (zero, from_llvm), + ], + )?; + // zero_clamped is now unsigned + let max_value = match to { + ptx_parser::ScalarType::U8 => u8::MAX as u64, + ptx_parser::ScalarType::U16 => u16::MAX as u64, + ptx_parser::ScalarType::U32 => u32::MAX as u64, + ptx_parser::ScalarType::U64 => u64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); + let fully_clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) }, + None, + &from.into(), + vec![(zero_clamped, from_llvm), (max, from_llvm)], + )?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, fully_clamped, to_llvm, dst) + }); + Ok(()) + } + + fn emit_cvt_float_to_int( + &mut self, + from: ast::ScalarType, + to: ast::ScalarType, + rounding: ast::RoundingMode, + arguments: ptx_parser::CvtArgs, + llvm_cast: Option< + unsafe extern "C" fn( + arg1: LLVMBuilderRef, + Val: LLVMValueRef, + DestTy: LLVMTypeRef, + Name: *const i8, + ) -> LLVMValueRef, + >, + ) -> Result<(), TranslateError> { + let prefix = match rounding { + ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", + ptx_parser::RoundingMode::Zero => "llvm.trunc", + ptx_parser::RoundingMode::NegativeInf => "llvm.floor", + ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", + }; + let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from)); + let rounded_float = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + None, + &from.into(), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, from), + )], + )?; + if let Some(llvm_cast) = llvm_cast { + let to = get_scalar_type(self.context, to); + let poisoned_dst = + unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFreeze(self.builder, poisoned_dst, dst) + }); + } else { + self.resolver.register(arguments.dst, rounded_float); + } + // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound + // values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt__ which + // saturates by default and we don't care about NaNs anyway + /* + let cast_intrinsic = format!( + "{}.{}.{}\0", + llvm_cast, + LLVMTypeDisplay(to), + LLVMTypeDisplay(from) + ); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, + Some(arguments.dst), + &to.into(), + vec![(rounded_float, get_scalar_type(self.context, from))], + )?; + */ + Ok(()) + } + + fn emit_rsqrt( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::RsqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match data.type_ { + ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", + ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_sqrt( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::SqrtArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", + (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", + (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_rcp( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let intrinsic = match (data.type_, data.kind) { + (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", + (_, ast::RcpKind::Compliant(rnd)) => { + return self.emit_rcp_compliant(data, arguments, rnd) + } + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![(self.resolver.value(arguments.src)?, type_)], + )?; + Ok(()) + } + + fn emit_rcp_compliant( + &mut self, + data: ptx_parser::RcpData, + arguments: ptx_parser::RcpArgs, + _rnd: ast::RoundingMode, + ) -> Result<(), TranslateError> { + let type_ = get_scalar_type(self.context, data.type_); + let one = unsafe { LLVMConstReal(type_, 1.0) }; + let src = self.resolver.value(arguments.src)?; + let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFDiv(self.builder, one, src, dst) + }); + unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; + Ok(()) + } + + fn emit_shr( + &mut self, + data: ptx_parser::ShrData, + arguments: ptx_parser::ShrArgs, + ) -> Result<(), TranslateError> { + let shift_fn = match data.kind { + ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr, + ptx_parser::RightShiftKind::Logical => LLVMBuildLShr, + }; + self.emit_shift( + data.type_, + arguments.dst, + arguments.src1, + arguments.src2, + shift_fn, + ) + } + + fn emit_shl( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::ShlArgs, + ) -> Result<(), TranslateError> { + self.emit_shift( + type_, + arguments.dst, + arguments.src1, + arguments.src2, + LLVMBuildShl, + ) + } + + fn emit_shift( + &mut self, + type_: ast::ScalarType, + dst: SpirvWord, + src1: SpirvWord, + src2: SpirvWord, + llvm_fn: unsafe extern "C" fn( + LLVMBuilderRef, + LLVMValueRef, + LLVMValueRef, + *const i8, + ) -> LLVMValueRef, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(src1)?; + let shift_size = self.resolver.value(src2)?; + let integer_bits = type_.layout().size() * 8; + let integer_bits_constant = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::U32), + integer_bits as u64, + 0, + ) + }; + let should_clamp = unsafe { + LLVMBuildICmp( + self.builder, + LLVMIntPredicate::LLVMIntUGE, + shift_size, + integer_bits_constant, + LLVM_UNNAMED.as_ptr(), + ) + }; + let llvm_type = get_scalar_type(self.context, type_); + let zero = unsafe { LLVMConstNull(llvm_type) }; + let normalized_shift_size = if type_.layout().size() >= 4 { + unsafe { + LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) + } + } else { + unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } + }; + let shifted = unsafe { + llvm_fn( + self.builder, + src1, + normalized_shift_size, + LLVM_UNNAMED.as_ptr(), + ) + }; + self.resolver.with_result(dst, |dst| unsafe { + LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) + }); + Ok(()) + } + + fn emit_ex2( + &mut self, + data: ptx_parser::TypeFtz, + arguments: ptx_parser::Ex2Args, + ) -> Result<(), TranslateError> { + let intrinsic = match data.type_ { + ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16", + ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32", + _ => return Err(error_unreachable()), + }; + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &data.type_.into(), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, data.type_), + )], + )?; + Ok(()) + } + + fn emit_lg2( + &mut self, + _data: ptx_parser::FlushToZero, + arguments: ptx_parser::Lg2Args, + ) -> Result<(), TranslateError> { + self.emit_intrinsic( + c"llvm.amdgcn.log.f32", + Some(arguments.dst), + &ast::ScalarType::F32.into(), + vec![( + self.resolver.value(arguments.src)?, + get_scalar_type(self.context, ast::ScalarType::F32.into()), + )], + )?; + Ok(()) + } + + fn emit_selp( + &mut self, + _data: ptx_parser::ScalarType, + arguments: ptx_parser::SelpArgs, + ) -> Result<(), TranslateError> { + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + LLVMBuildSelect(self.builder, src3, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_rem( + &mut self, + data: ptx_parser::ScalarType, + arguments: ptx_parser::RemArgs, + ) -> Result<(), TranslateError> { + let llvm_fn = match data.kind() { + ptx_parser::ScalarKind::Unsigned => LLVMBuildURem, + ptx_parser::ScalarKind::Signed => LLVMBuildSRem, + _ => return Err(error_unreachable()), + }; + let src1 = self.resolver.value(arguments.src1)?; + let src2 = self.resolver.value(arguments.src2)?; + self.resolver.with_result(arguments.dst, |dst_name| unsafe { + llvm_fn(self.builder, src1, src2, dst_name) + }); + Ok(()) + } + + fn emit_popc( + &mut self, + type_: ptx_parser::ScalarType, + arguments: ptx_parser::PopcArgs, + ) -> Result<(), TranslateError> { + let intrinsic = match type_ { + ast::ScalarType::B32 => c"llvm.ctpop.i32", + ast::ScalarType::B64 => c"llvm.ctpop.i64", + _ => return Err(error_unreachable()), + }; + let llvm_type = get_scalar_type(self.context, type_); + self.emit_intrinsic( + intrinsic, + Some(arguments.dst), + &type_.into(), + vec![(self.resolver.value(arguments.src)?, llvm_type)], + )?; + Ok(()) + } + + fn emit_min( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MinArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + return Err(error_todo()) + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_().into(), + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], + )?; + Ok(()) + } + + fn emit_max( + &mut self, + data: ptx_parser::MinMaxDetails, + arguments: ptx_parser::MaxArgs, + ) -> Result<(), TranslateError> { + let llvm_prefix = match data { + ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", + ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { + return Err(error_todo()) + } + ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum", + }; + let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); + let llvm_type = get_scalar_type(self.context, data.type_()); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_().into(), + vec![ + (self.resolver.value(arguments.src1)?, llvm_type), + (self.resolver.value(arguments.src2)?, llvm_type), + ], + )?; + Ok(()) + } + + fn emit_fma( + &mut self, + data: ptx_parser::ArithFloat, + arguments: ptx_parser::FmaArgs, + ) -> Result<(), TranslateError> { + let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_)); + self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, + Some(arguments.dst), + &data.type_.into(), + vec![ + ( + self.resolver.value(arguments.src1)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src2)?, + get_scalar_type(self.context, data.type_), + ), + ( + self.resolver.value(arguments.src3)?, + get_scalar_type(self.context, data.type_), + ), + ], + )?; + Ok(()) + } + + fn emit_mad( + &mut self, + data: ptx_parser::MadDetails, + arguments: ptx_parser::MadArgs, + ) -> Result<(), TranslateError> { + let mul_control = match data { + ptx_parser::MadDetails::Float(mad_float) => { + return self.emit_fma( + mad_float, + ast::FmaArgs { + dst: arguments.dst, + src1: arguments.src1, + src2: arguments.src2, + src3: arguments.src3, + }, + ) + } + ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()), + ptx_parser::MadDetails::Integer { type_, control, .. } => { + ast::MulDetails::Integer { control, type_ } + } + }; + let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?; + let src3 = self.resolver.value(arguments.src3)?; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildAdd(self.builder, temp, src3, dst) + }); + Ok(()) + } + + fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> { + unsafe { + LLVMZludaBuildFence( + self.builder, + LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent, + get_scope_membar(data)?, + LLVM_UNNAMED.as_ptr(), + ) + }; + Ok(()) + } + + fn emit_prmt( + &mut self, + control: u16, + arguments: ptx_parser::PrmtArgs, + ) -> Result<(), TranslateError> { + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo); + } + let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; + let mut components = [ + unsafe { LLVMConstInt(u32_type, components[0] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[1] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[2] as _, 0) }, + unsafe { LLVMConstInt(u32_type, components[3] as _, 0) }, + ]; + let components_indices = + unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) }; + let src1 = self.resolver.value(arguments.src1)?; + let src1_vector = + unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + let src2 = self.resolver.value(arguments.src2)?; + let src2_vector = + unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildShuffleVector( + self.builder, + src1_vector, + src2_vector, + components_indices, + dst, + ) + }); + Ok(()) + } + + /* + // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` + // Should be available in LLVM 19 + fn with_rounding(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T { + let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32); + let void_type = unsafe { LLVMVoidTypeInContext(self.context) }; + let get_rounding = c"llvm.get.rounding"; + let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) }; + let mut get_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) }; + if get_rounding_fn == ptr::null_mut() { + get_rounding_fn = unsafe { + LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type) + }; + } + let set_rounding = c"llvm.set.rounding"; + let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) }; + let mut set_rounding_fn = + unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) }; + if set_rounding_fn == ptr::null_mut() { + set_rounding_fn = unsafe { + LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type) + }; + } + let mut preserved_rounding_mode = unsafe { + LLVMBuildCall2( + self.builder, + get_rounding_fn_type, + get_rounding_fn, + ptr::null_mut(), + 0, + LLVM_UNNAMED.as_ptr(), + ) + }; + let mut requested_rounding = unsafe { + LLVMConstInt( + get_scalar_type(self.context, ast::ScalarType::B32), + rounding_to_llvm(rnd) as u64, + 0, + ) + }; + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut requested_rounding, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + let result = fn_(self); + unsafe { + LLVMBuildCall2( + self.builder, + set_rounding_fn_type, + set_rounding_fn, + &mut preserved_rounding_mode, + 1, + LLVM_UNNAMED.as_ptr(), + ) + }; + result + } + */ +} + +fn get_pointer_type<'ctx>( + context: LLVMContextRef, + to_space: ast::StateSpace, +) -> Result { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) }) +} + +// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes +fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup-one-as", + ast::MemScope::Gpu => c"agent-one-as", + ast::MemScope::Sys => c"one-as", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + +fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> { + Ok(match scope { + ast::MemScope::Cta => c"workgroup", + ast::MemScope::Gpu => c"agent", + ast::MemScope::Sys => c"", + ast::MemScope::Cluster => todo!(), + } + .as_ptr()) +} + +fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { + match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease, + } +} + +fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering { + match semantics { + ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic, + ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire, + } +} + +fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result { + Ok(match type_ { + ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar), + ast::Type::Vector(size, scalar) => { + let base_type = get_scalar_type(context, *scalar); + unsafe { LLVMVectorType(base_type, *size as u32) } + } + ast::Type::Array(vec, scalar, dimensions) => { + let mut underlying_type = get_scalar_type(context, *scalar); + if let Some(size) = vec { + underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) }; + } + if dimensions.is_empty() { + return Ok(unsafe { LLVMArrayType2(underlying_type, 0) }); + } + dimensions + .iter() + .rfold(underlying_type, |result, dimension| unsafe { + LLVMArrayType2(result, *dimension as u64) + }) + } + ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?, + }) +} + +fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef { + match type_ { + ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) }, + ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe { + LLVMInt8TypeInContext(context) + }, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe { + LLVMInt16TypeInContext(context) + }, + ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe { + LLVMInt32TypeInContext(context) + }, + ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe { + LLVMInt64TypeInContext(context) + }, + ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) }, + ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) }, + ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) }, + ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) }, + ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) }, + ast::ScalarType::U16x2 => todo!(), + ast::ScalarType::S16x2 => todo!(), + ast::ScalarType::F16x2 => todo!(), + ast::ScalarType::BF16x2 => todo!(), + } +} + +fn get_function_type<'a>( + context: LLVMContextRef, + mut return_args: impl ExactSizeIterator, + input_args: impl ExactSizeIterator>, +) -> Result { + let mut input_args = input_args.collect::, _>>()?; + let return_type = match return_args.len() { + 0 => unsafe { LLVMVoidTypeInContext(context) }, + 1 => get_type(context, return_args.next().unwrap())?, + _ => todo!(), + }; + Ok(unsafe { + LLVMFunctionType( + return_type, + input_args.as_mut_ptr(), + input_args.len() as u32, + 0, + ) + }) +} + +fn get_state_space(space: ast::StateSpace) -> Result { + match space { + ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), + ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), + ast::StateSpace::Param => Err(TranslateError::Todo), + ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), + ast::StateSpace::ParamFunc => Err(TranslateError::Todo), + ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), + ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), + ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE), + ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE), + ast::StateSpace::SharedCta => Err(TranslateError::Todo), + ast::StateSpace::SharedCluster => Err(TranslateError::Todo), + } +} + +struct ResolveIdent { + words: HashMap, + values: HashMap, +} + +impl ResolveIdent { + fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { + ResolveIdent { + words: HashMap::new(), + values: HashMap::new(), + } + } + + fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T { + let str = match self.words.entry(word) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let mut text = word.0.to_string(); + text.push('\0'); + entry.insert(text) + } + }; + fn_(&str[..str.len() - 1]) + } + + fn get_or_add(&mut self, word: SpirvWord) -> &str { + self.get_or_ad_impl(word, |x| x) + } + + fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 { + self.get_or_add(word).as_ptr().cast() + } + + fn register(&mut self, word: SpirvWord, v: LLVMValueRef) { + self.values.insert(word, v); + } fn value(&self, word: SpirvWord) -> Result { self.values @@ -747,8 +2395,57 @@ impl ResolveIdent { .ok_or_else(|| error_unreachable()) } - fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) { + fn with_result( + &mut self, + word: SpirvWord, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast())); self.register(word, t); + t + } + + fn with_result_option( + &mut self, + word: Option, + fn_: impl FnOnce(*const i8) -> LLVMValueRef, + ) -> LLVMValueRef { + match word { + Some(word) => self.with_result(word, fn_), + None => fn_(LLVM_UNNAMED.as_ptr()), + } + } +} + +struct LLVMTypeDisplay(ast::ScalarType); + +impl std::fmt::Display for LLVMTypeDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + ast::ScalarType::Pred => write!(f, "i1"), + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"), + ptx_parser::ScalarType::B128 => write!(f, "i128"), + ast::ScalarType::F16 => write!(f, "f16"), + ptx_parser::ScalarType::BF16 => write!(f, "bfloat"), + ast::ScalarType::F32 => write!(f, "f32"), + ast::ScalarType::F64 => write!(f, "f64"), + ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"), + ast::ScalarType::F16x2 => write!(f, "v2f16"), + ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"), + } + } +} + +/* +fn rounding_to_llvm(this: ast::RoundingMode) -> u32 { + match this { + ptx_parser::RoundingMode::Zero => 0, + ptx_parser::RoundingMode::NearestEven => 1, + ptx_parser::RoundingMode::PositiveInf => 2, + ptx_parser::RoundingMode::NegativeInf => 3, } } +*/ diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs deleted file mode 100644 index 120a477..0000000 --- a/ptx/src/pass/emit_spirv.rs +++ /dev/null @@ -1,2762 +0,0 @@ -use super::*; -use half::f16; -use ptx_parser as ast; -use rspirv::{binary::Assemble, dr}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - mem, -}; - -pub(super) fn run<'input>( - mut builder: dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - call_map: MethodsCallMap<'input>, - denorm_information: HashMap< - ptx_parser::MethodName, - HashMap, - >, - directives: Vec>, -) -> Result<(dr::Module, HashMap, CString), TranslateError> { - builder.set_version(1, 3); - emit_capabilities(&mut builder); - emit_extensions(&mut builder); - let opencl_id = emit_opencl_import(&mut builder); - emit_memory_model(&mut builder); - let mut map = TypeWordMap::new(&mut builder); - //emit_builtins(&mut builder, &mut map, &id_defs); - let mut kernel_info = HashMap::new(); - let (build_options, should_flush_denorms) = - emit_denorm_build_string(&call_map, &denorm_information); - let (directives, globals_use_map) = get_globals_use_map(directives); - emit_directives( - &mut builder, - &mut map, - &id_defs, - opencl_id, - should_flush_denorms, - &call_map, - globals_use_map, - directives, - &mut kernel_info, - )?; - Ok((builder.module(), kernel_info, build_options)) -} - -fn emit_capabilities(builder: &mut dr::Builder) { - builder.capability(spirv::Capability::GenericPointer); - builder.capability(spirv::Capability::Linkage); - builder.capability(spirv::Capability::Addresses); - builder.capability(spirv::Capability::Kernel); - builder.capability(spirv::Capability::Int8); - builder.capability(spirv::Capability::Int16); - builder.capability(spirv::Capability::Int64); - builder.capability(spirv::Capability::Float16); - builder.capability(spirv::Capability::Float64); - builder.capability(spirv::Capability::DenormFlushToZero); - // TODO: re-enable when Intel float control extension works - //builder.capability(spirv::Capability::FunctionFloatControlINTEL); -} - -// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html -fn emit_extensions(builder: &mut dr::Builder) { - // TODO: re-enable when Intel float control extension works - //builder.extension("SPV_INTEL_float_controls2"); - builder.extension("SPV_KHR_float_controls"); - builder.extension("SPV_KHR_no_integer_wrap_decoration"); -} - -fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { - builder.ext_inst_import("OpenCL.std") -} - -fn emit_memory_model(builder: &mut dr::Builder) { - builder.memory_model( - spirv::AddressingModel::Physical64, - spirv::MemoryModel::OpenCL, - ); -} - -struct TypeWordMap { - void: spirv::Word, - complex: HashMap, - constants: HashMap<(SpirvType, u64), SpirvWord>, -} - -impl TypeWordMap { - fn new(b: &mut dr::Builder) -> TypeWordMap { - let void = b.type_void(None); - TypeWordMap { - void: void, - complex: HashMap::::new(), - constants: HashMap::new(), - } - } - - fn void(&self) -> spirv::Word { - self.void - } - - fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { - let key: SpirvScalarKey = t.into(); - self.get_or_add_spirv_scalar(b, key) - } - - fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { - *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { - SpirvWord(match key { - SpirvScalarKey::B8 => b.type_int(None, 8, 0), - SpirvScalarKey::B16 => b.type_int(None, 16, 0), - SpirvScalarKey::B32 => b.type_int(None, 32, 0), - SpirvScalarKey::B64 => b.type_int(None, 64, 0), - SpirvScalarKey::F16 => b.type_float(None, 16), - SpirvScalarKey::F32 => b.type_float(None, 32), - SpirvScalarKey::F64 => b.type_float(None, 64), - SpirvScalarKey::Pred => b.type_bool(None), - SpirvScalarKey::F16x2 => todo!(), - }) - }) - } - - fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { - match t { - SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), - SpirvType::Pointer(ref typ, storage) => { - let base = self.get_or_add(b, *typ.clone()); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) - } - SpirvType::Vector(typ, len) => { - let base = self.get_or_add_spirv_scalar(b, typ); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32))) - } - SpirvType::Array(typ, array_dimensions) => { - let (base_type, length) = match &*array_dimensions { - &[] => { - return self.get_or_add(b, SpirvType::Base(typ)); - } - &[len] => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self.get_or_add_spirv_scalar(b, typ); - let len_const = b.constant_u32(u32_type.0, None, len); - (base, len_const) - } - array_dimensions => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self - .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); - let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); - (base, len_const) - } - }; - *self - .complex - .entry(SpirvType::Array(typ, array_dimensions)) - .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) - } - SpirvType::Func(ref out_params, ref in_params) => { - let out_t = match out_params { - Some(p) => self.get_or_add(b, *p.clone()), - None => SpirvWord(self.void()), - }; - let in_t = in_params - .iter() - .map(|t| self.get_or_add(b, t.clone()).0) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) - } - SpirvType::Struct(ref underlying) => { - let underlying_ids = underlying - .iter() - .map(|t| self.get_or_add_spirv_scalar(b, *t).0) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) - } - } - } - - fn get_or_add_fn( - &mut self, - b: &mut dr::Builder, - in_params: impl Iterator, - mut out_params: impl ExactSizeIterator, - ) -> (SpirvWord, SpirvWord) { - let (out_args, out_spirv_type) = if out_params.len() == 0 { - (None, SpirvWord(self.void())) - } else if out_params.len() == 1 { - let arg_as_key = out_params.next().unwrap(); - ( - Some(Box::new(arg_as_key.clone())), - self.get_or_add(b, arg_as_key), - ) - } else { - // TODO: support multiple return values - todo!() - }; - ( - out_spirv_type, - self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), - ) - } - - fn get_or_add_constant( - &mut self, - b: &mut dr::Builder, - typ: &ast::Type, - init: &[u8], - ) -> Result { - Ok(match typ { - ast::Type::Scalar(t) => match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v), - ), - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v, - |b, result_type, v| b.constant_u64(result_type, None, v), - ), - ast::ScalarType::F16 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u16>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), - ), - ast::ScalarType::F32 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u32>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v), - ), - ast::ScalarType::F64 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u64>(v) }, - |b, result_type, v| b.constant_f64(result_type, None, v), - ), - ast::ScalarType::F16x2 => return Err(TranslateError::Todo), - ast::ScalarType::Pred => self.get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| { - if v == 0 { - b.constant_false(result_type, None) - } else { - b.constant_true(result_type, None) - } - }, - ), - ast::ScalarType::S16x2 - | ast::ScalarType::U16x2 - | ast::ScalarType::BF16 - | ast::ScalarType::BF16x2 - | ast::ScalarType::B128 => todo!(), - }, - ast::Type::Vector(len, typ) => { - let result_type = - self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); - let size_of_t = typ.size_of(); - let components = (0..*len) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - ast::Type::Array(_, typ, dims) => match dims.as_slice() { - [] => return Err(error_unreachable()), - [dim] => { - let result_type = self - .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); - let size_of_t = typ.size_of(); - let components = (0..*dim) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - [first_dim, rest @ ..] => { - let result_type = self.get_or_add( - b, - SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), - ); - let size_of_t = rest - .iter() - .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); - let components = (0..*first_dim) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Array(None, *typ, rest.to_vec()), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - }, - ast::Type::Pointer(..) => return Err(error_unreachable()), - }) - } - - fn get_or_add_constant_single< - T: Copy, - CastAsU64: FnOnce(T) -> u64, - InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, - >( - &mut self, - b: &mut dr::Builder, - key: ast::ScalarType, - init: &[u8], - cast: CastAsU64, - f: InsertConstant, - ) -> SpirvWord { - let value = unsafe { *(init.as_ptr() as *const T) }; - let value_64 = cast(value); - let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); - match self.constants.get(&ht_key) { - Some(value) => *value, - None => { - let spirv_type = self.get_or_add_scalar(b, key); - let result = SpirvWord(f(b, spirv_type.0, value)); - self.constants.insert(ht_key, result); - result - } - } - } -} - -#[derive(PartialEq, Eq, Hash, Clone)] -enum SpirvType { - Base(SpirvScalarKey), - Vector(SpirvScalarKey, u8), - Array(SpirvScalarKey, Vec), - Pointer(Box, spirv::StorageClass), - Func(Option>, Vec), - Struct(Vec), -} - -impl SpirvType { - fn new(t: ast::Type) -> Self { - match t { - ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), - ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( - Box::new(SpirvType::Base(pointer_t.into())), - space_to_spirv(space), - ), - } - } - - fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { - let key = Self::new(t); - SpirvType::Pointer(Box::new(key), outer_space) - } -} - -impl From for SpirvType { - fn from(t: ast::ScalarType) -> Self { - SpirvType::Base(t.into()) - } -} -// SPIR-V integer type definitions are signless, more below: -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -enum SpirvScalarKey { - B8, - B16, - B32, - B64, - F16, - F32, - F64, - Pred, - F16x2, -} - -impl From for SpirvScalarKey { - fn from(t: ast::ScalarType) -> Self { - match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - SpirvScalarKey::B16 - } - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { - SpirvScalarKey::B32 - } - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { - SpirvScalarKey::B64 - } - ast::ScalarType::F16 => SpirvScalarKey::F16, - ast::ScalarType::F32 => SpirvScalarKey::F32, - ast::ScalarType::F64 => SpirvScalarKey::F64, - ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, - ast::ScalarType::Pred => SpirvScalarKey::Pred, - ast::ScalarType::S16x2 - | ast::ScalarType::U16x2 - | ast::ScalarType::BF16 - | ast::ScalarType::BF16x2 - | ast::ScalarType::B128 => todo!(), - } - } -} - -fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { - match this { - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Local => spirv::StorageClass::Function, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Param => spirv::StorageClass::Function, - ast::StateSpace::Reg => spirv::StorageClass::Function, - ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta => todo!(), - } -} - -// TODO: remove this once we have pef-function support for denorms -fn emit_denorm_build_string<'input>( - call_map: &MethodsCallMap, - denorm_information: &HashMap< - ast::MethodName<'input, SpirvWord>, - HashMap, - >, -) -> (CString, bool) { - let denorm_counts = denorm_information - .iter() - .map(|(method, meth_denorm)| { - let f16_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - let f32_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - (method, (f16_count + f32_count)) - }) - .collect::>(); - let mut flush_over_preserve = 0; - for (kernel, children) in call_map.kernels() { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Kernel(kernel)) - .unwrap_or(&0); - for child_fn in children { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Func(*child_fn)) - .unwrap_or(&0); - } - } - if flush_over_preserve > 0 { - ( - CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), - true, - ) - } else { - (CString::new("-ze-take-global-address").unwrap(), false) - } -} - -fn get_globals_use_map<'input>( - directives: Vec>, -) -> ( - Vec>, - HashMap, HashSet>, -) { - let mut known_globals = HashSet::new(); - for directive in directives.iter() { - match directive { - Directive::Variable(_, ast::Variable { name, .. }) => { - known_globals.insert(*name); - } - Directive::Method(..) => {} - } - } - let mut symbol_uses_map = HashMap::new(); - let directives = directives - .into_iter() - .map(|directive| match directive { - Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive, - Directive::Method(Function { - func_decl, - body: Some(mut statements), - globals, - import_as, - tuning, - linkage, - }) => { - let method_name = func_decl.borrow().name; - statements = statements - .into_iter() - .map(|statement| { - statement.visit_map( - &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { - if known_globals.contains(&symbol) { - multi_hash_map_append( - &mut symbol_uses_map, - method_name, - symbol, - ); - } - Ok::<_, TranslateError>(symbol) - }, - ) - }) - .collect::, _>>() - .unwrap(); - Directive::Method(Function { - func_decl, - body: Some(statements), - globals, - import_as, - tuning, - linkage, - }) - } - }) - .collect::>(); - (directives, symbol_uses_map) -} - -fn emit_directives<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl_id: spirv::Word, - should_flush_denorms: bool, - call_map: &MethodsCallMap<'input>, - globals_use_map: HashMap, HashSet>, - directives: Vec>, - kernel_info: &mut HashMap, -) -> Result<(), TranslateError> { - let empty_body = Vec::new(); - for d in directives.iter() { - match d { - Directive::Variable(linking, var) => { - emit_variable(builder, map, id_defs, *linking, &var)?; - } - Directive::Method(f) => { - let f_body = match &f.body { - Some(f) => f, - None => { - if f.linkage.contains(ast::LinkingDirective::EXTERN) { - &empty_body - } else { - continue; - } - } - }; - for var in f.globals.iter() { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - let func_decl = (*f.func_decl).borrow(); - let fn_id = emit_function_header( - builder, - map, - &id_defs, - &*func_decl, - call_map, - &globals_use_map, - kernel_info, - )?; - if matches!(func_decl.name, ast::MethodName::Kernel(_)) { - if should_flush_denorms { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [16], - ); - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [32], - ); - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [64], - ); - } - // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::ContractionOff, - [], - ); - for t in f.tuning.iter() { - match *t { - ast::TuningDirective::MaxNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, - [nx, ny, nz], - ); - } - ast::TuningDirective::ReqNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::LocalSize, - [nx, ny, nz], - ); - } - // Too architecture specific - ast::TuningDirective::MaxNReg(..) - | ast::TuningDirective::MinNCtaPerSm(..) => {} - } - } - } - emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; - emit_function_linkage(builder, id_defs, f, fn_id)?; - builder.select_block(None)?; - builder.end_function()?; - } - } - } - Ok(()) -} - -fn emit_variable<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - linking: ast::LinkingDirective, - var: &ast::Variable, -) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.state_space { - ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { - (false, spirv::StorageClass::Function) - } - ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), - ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), - ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), - ast::StateSpace::Generic => todo!(), - ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta => todo!(), - }; - let initalizer = if var.array_init.len() > 0 { - Some( - map.get_or_add_constant( - builder, - &ast::Type::from(var.v_type.clone()), - &*var.array_init, - )? - .0, - ) - } else if must_init { - let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); - Some(builder.constant_null(type_id.0, None)) - } else { - None - }; - let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); - builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer); - if let Some(align) = var.align { - builder.decorate( - var.name.0, - spirv::Decoration::Alignment, - [dr::Operand::LiteralInt32(align)].iter().cloned(), - ); - } - if var.state_space != ast::StateSpace::Shared - || !linking.contains(ast::LinkingDirective::EXTERN) - { - emit_linking_decoration(builder, id_defs, None, var.name, linking); - } - Ok(()) -} - -fn emit_function_header<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - defined_globals: &GlobalStringIdResolver<'input>, - func_decl: &ast::MethodDeclaration<'input, SpirvWord>, - call_map: &MethodsCallMap<'input>, - globals_use_map: &HashMap, HashSet>, - kernel_info: &mut HashMap, -) -> Result { - if let ast::MethodName::Kernel(name) = func_decl.name { - let args_lens = func_decl - .input_arguments - .iter() - .map(|param| { - ( - type_size_of(¶m.v_type), - matches!(param.v_type, ast::Type::Pointer(..)), - ) - }) - .collect(); - kernel_info.insert( - name.to_string(), - KernelInfo { - arguments_sizes: args_lens, - uses_shared_mem: func_decl.shared_mem.is_some(), - }, - ); - } - let (ret_type, func_type) = get_function_type( - builder, - map, - effective_input_arguments(func_decl).map(|(_, typ)| typ), - &func_decl.return_arguments, - ); - let fn_id = match func_decl.name { - ast::MethodName::Kernel(name) => { - let fn_id = defined_globals.get_id(name)?; - let interface = globals_use_map - .get(&ast::MethodName::Kernel(name)) - .into_iter() - .flatten() - .copied() - .chain({ - call_map - .get_kernel_children(name) - .copied() - .flat_map(|subfunction| { - globals_use_map - .get(&ast::MethodName::Func(subfunction)) - .into_iter() - .flatten() - .copied() - }) - .into_iter() - }) - .map(|word| word.0) - .collect::>(); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface); - fn_id - } - ast::MethodName::Func(name) => name, - }; - builder.begin_function( - ret_type.0, - Some(fn_id.0), - spirv::FunctionControl::NONE, - func_type.0, - )?; - for (name, typ) in effective_input_arguments(func_decl) { - let result_type = map.get_or_add(builder, typ); - builder.function_parameter(Some(name.0), result_type.0)?; - } - Ok(fn_id) -} - -pub fn type_size_of(this: &ast::Type) -> usize { - match this { - ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(_, typ, len) => len - .iter() - .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(..) => mem::size_of::(), - } -} -fn emit_function_body_ops<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl: spirv::Word, - func: &[ExpandedStatement], -) -> Result<(), TranslateError> { - for s in func { - match s { - Statement::Label(id) => { - if builder.selected_block().is_some() { - builder.branch(id.0)?; - } - builder.begin_block(Some(id.0))?; - } - _ => { - if builder.selected_block().is_none() && builder.selected_function().is_some() { - builder.begin_block(None)?; - } - } - } - match s { - Statement::Label(_) => (), - Statement::Variable(var) => { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - Statement::Constant(cnst) => { - let typ_id = map.get_or_add_scalar(builder, cnst.typ); - match (cnst.typ, cnst.value) { - (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); - } - (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); - } - (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { - builder.constant_f32( - typ_id.0, - Some(cnst.dst.0), - f16::from_f32(value).to_f32(), - ); - } - (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { - builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { - builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { - builder.constant_f32( - typ_id.0, - Some(cnst.dst.0), - f16::from_f64(value).to_f32(), - ); - } - (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { - builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); - } - (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { - builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst.0)); - } else { - builder.constant_true(bool_type, Some(cnst.dst.0)); - } - } - (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst.0)); - } else { - builder.constant_true(bool_type, Some(cnst.dst.0)); - } - } - _ => return Err(error_mismatched_type()), - } - } - Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, - Statement::Conditional(bra) => { - builder.branch_conditional( - bra.predicate.0, - bra.if_true.0, - bra.if_false.0, - iter::empty(), - )?; - } - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { - // TODO: implement properly - let zero = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U64), - &vec_repr(0u64), - )?; - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); - builder.copy_object(result_type.0, Some(dst.0), zero.0)?; - } - Statement::Instruction(inst) => match inst { - ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(), - ast::Instruction::Call { data, arguments } => { - let (result_type, result_id) = - match (&*data.return_arguments, &*arguments.return_arguments) { - ([(type_, space)], [id]) => { - if *space != ast::StateSpace::Reg { - return Err(error_unreachable()); - } - ( - map.get_or_add(builder, SpirvType::new(type_.clone())).0, - Some(id.0), - ) - } - ([], []) => (map.void(), None), - _ => todo!(), - }; - let arg_list = arguments - .input_arguments - .iter() - .map(|id| id.0) - .collect::>(); - builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; - } - ast::Instruction::Abs { data, arguments } => { - emit_abs(builder, map, opencl, data, arguments)? - } - // SPIR-V does not support marking jumps as guaranteed-converged - ast::Instruction::Bra { arguments, .. } => { - builder.branch(arguments.src.0)?; - } - ast::Instruction::Ld { data, arguments } => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); - builder.load( - result_type.0, - Some(arguments.dst.0), - arguments.src.0, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - type_size_of(&ast::Type::from(data.typ.clone())) as u32, - )] - .iter() - .cloned(), - )?; - } - ast::Instruction::St { data, arguments } => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - builder.store( - arguments.src1.0, - arguments.src2.0, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - type_size_of(&ast::Type::from(data.typ.clone())) as u32, - )] - .iter() - .cloned(), - )?; - } - // SPIR-V does not support ret as guaranteed-converged - ast::Instruction::Ret { .. } => builder.ret()?, - ast::Instruction::Mov { data, arguments } => { - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); - builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Mul { data, arguments } => match data { - ast::MulDetails::Integer { type_, control } => { - emit_mul_int(builder, map, opencl, *type_, *control, arguments)? - } - ast::MulDetails::Float(ref ctr) => { - emit_mul_float(builder, map, ctr, arguments)? - } - }, - ast::Instruction::Add { data, arguments } => match data { - ast::ArithDetails::Integer(desc) => { - emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? - } - ast::ArithDetails::Float(desc) => { - emit_add_float(builder, map, desc, arguments)? - } - }, - ast::Instruction::Setp { data, arguments } => { - if arguments.dst2.is_some() { - todo!() - } - emit_setp(builder, map, data, arguments)?; - } - ast::Instruction::Not { data, arguments } => { - let result_type = map.get_or_add(builder, SpirvType::from(*data)); - let result_id = Some(arguments.dst.0); - let operand = arguments.src; - match data { - ast::ScalarType::Pred => { - logical_not(builder, result_type.0, result_id, operand.0) - } - _ => builder.not(result_type.0, result_id, operand.0), - }?; - } - ast::Instruction::Shl { data, arguments } => { - let full_type = ast::Type::Scalar(*data); - let size_of = type_size_of(&full_type); - let result_type = map.get_or_add(builder, SpirvType::new(full_type)); - let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; - builder.shift_left_logical( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - ast::Instruction::Shr { data, arguments } => { - let full_type = ast::ScalarType::from(data.type_); - let size_of = full_type.size_of(); - let result_type = map.get_or_add_scalar(builder, full_type).0; - let offset_src = - insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; - match data.kind { - ptx_parser::RightShiftKind::Arithmetic => { - builder.shift_right_arithmetic( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - ptx_parser::RightShiftKind::Logical => { - builder.shift_right_logical( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - } - } - ast::Instruction::Cvt { data, arguments } => { - emit_cvt(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Cvta { data, arguments } => { - // This would be only meaningful if const/slm/global pointers - // had a different format than generic pointers, but they don't pretty much by ptx definition - // Honestly, I have no idea why this instruction exists and is emitted by the compiler - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); - builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::SetpBool { .. } => todo!(), - ast::Instruction::Mad { data, arguments } => match data { - ast::MadDetails::Integer { - type_, - control, - saturate, - } => { - if *saturate { - todo!() - } - if type_.kind() == ast::ScalarKind::Signed { - emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? - } else { - emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? - } - } - ast::MadDetails::Float(desc) => { - emit_mad_float(builder, map, opencl, desc, arguments)? - } - }, - ast::Instruction::Fma { data, arguments } => { - emit_fma_float(builder, map, opencl, data, arguments)? - } - ast::Instruction::Or { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data).0; - if *data == ast::ScalarType::Pred { - builder.logical_or( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } else { - builder.bitwise_or( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - } - ast::Instruction::Sub { data, arguments } => match data { - ast::ArithDetails::Integer(desc) => { - emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; - } - ast::ArithDetails::Float(desc) => { - emit_sub_float(builder, map, desc, arguments)?; - } - }, - ast::Instruction::Min { data, arguments } => { - emit_min(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Max { data, arguments } => { - emit_max(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Rcp { data, arguments } => { - emit_rcp(builder, map, opencl, data, arguments)?; - } - ast::Instruction::And { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data); - if *data == ast::ScalarType::Pred { - builder.logical_and( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } else { - builder.bitwise_and( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - } - ast::Instruction::Selp { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data); - builder.select( - result_type.0, - Some(arguments.dst.0), - arguments.src3.0, - arguments.src1.0, - arguments.src2.0, - )?; - } - // TODO: implement named barriers - ast::Instruction::Bar { data, arguments } => { - let workgroup_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(spirv::Scope::Workgroup as u32), - )?; - let barrier_semantics = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr( - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - )?; - builder.control_barrier( - workgroup_scope.0, - workgroup_scope.0, - barrier_semantics.0, - )?; - } - ast::Instruction::Atom { data, arguments } => { - emit_atom(builder, map, data, arguments)?; - } - ast::Instruction::AtomCas { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope_to_spirv(data.scope) as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics_to_spirv(data.semantics).bits()), - )?; - builder.atomic_compare_exchange( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - memory_const.0, - semantics_const.0, - semantics_const.0, - arguments.src3.0, - arguments.src2.0, - )?; - } - ast::Instruction::Div { data, arguments } => match data { - ast::DivDetails::Unsigned(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.u_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::DivDetails::Signed(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.s_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::DivDetails::Float(t) => { - let result_type = map.get_or_add_scalar(builder, t.type_.into()); - builder.f_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - emit_float_div_decoration(builder, arguments.dst, t.kind); - } - }, - ast::Instruction::Sqrt { data, arguments } => { - emit_sqrt(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Rsqrt { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_.into()); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::rsqrt as spirv::Word, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Neg { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_); - let negate_func = if data.type_.kind() == ast::ScalarKind::Float { - dr::Builder::f_negate - } else { - dr::Builder::s_negate - }; - negate_func( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src.0, - )?; - } - ast::Instruction::Sin { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::sin as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Cos { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::cos as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Lg2 { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::log2 as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Ex2 { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::exp2 as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Clz { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::clz as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Brev { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Popc { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Xor { data, arguments } => { - let builder_fn: fn( - &mut dr::Builder, - u32, - Option, - u32, - u32, - ) -> Result = match data { - ast::ScalarType::Pred => emit_logical_xor_spirv, - _ => dr::Builder::bitwise_xor, - }; - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder_fn( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::Instruction::Bfe { .. } - | ast::Instruction::Bfi { .. } - | ast::Instruction::Activemask { .. } => { - // Should have beeen replaced with a funciton call earlier - return Err(error_unreachable()); - } - - ast::Instruction::Rem { data, arguments } => { - let builder_fn = if data.kind() == ast::ScalarKind::Signed { - dr::Builder::s_mod - } else { - dr::Builder::u_mod - }; - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder_fn( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::Instruction::Prmt { data, arguments } => { - let control = *data as u32; - let components = [ - (control >> 0) & 0b1111, - (control >> 4) & 0b1111, - (control >> 8) & 0b1111, - (control >> 12) & 0b1111, - ]; - if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo); - } - let vec4_b8_type = - map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); - let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?; - let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?; - let dst_vector = builder.vector_shuffle( - vec4_b8_type.0, - None, - src1_vector, - src2_vector, - components, - )?; - builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?; - } - ast::Instruction::Membar { data } => { - let (scope, semantics) = match data { - ast::MemScope::Cta => ( - spirv::Scope::Workgroup, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Gpu => ( - spirv::Scope::Device, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Sys => ( - spirv::Scope::CrossDevice, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - - ast::MemScope::Cluster => todo!(), - }; - let spirv_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope as u32), - )?; - let spirv_semantics = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics), - )?; - builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; - } - }, - Statement::LoadVar(details) => { - emit_load_var(builder, map, details)?; - } - Statement::StoreVar(details) => { - let dst_ptr = match details.member_index { - Some(index) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to( - details.typ.clone(), - spirv::StorageClass::Function, - ), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - builder.in_bounds_access_chain( - result_ptr_type.0, - None, - details.arg.src1.0, - [index_spirv.0].iter().copied(), - )? - } - None => details.arg.src1.0, - }; - builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; - } - Statement::RetValue(_, id) => { - builder.ret_value(id.0)?; - } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src, - }) => { - let u8_pointer = map.get_or_add( - builder, - SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), - ); - let result_type = map.get_or_add( - builder, - SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), - ); - let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; - let temp = builder.in_bounds_ptr_access_chain( - u8_pointer.0, - None, - ptr_src_u8, - offset_src.0, - iter::empty(), - )?; - builder.bitcast(result_type.0, Some(dst.0), temp)?; - } - Statement::RepackVector(repack) => { - if repack.is_extract { - let scalar_type = map.get_or_add_scalar(builder, repack.typ); - for (index, dst_id) in repack.unpacked.iter().enumerate() { - builder.composite_extract( - scalar_type.0, - Some(dst_id.0), - repack.packed.0, - [index as u32].iter().copied(), - )?; - } - } else { - let vector_type = map.get_or_add( - builder, - SpirvType::Vector( - SpirvScalarKey::from(repack.typ), - repack.unpacked.len() as u8, - ), - ); - let mut temp_vec = builder.undef(vector_type.0, None); - for (index, src_id) in repack.unpacked.iter().enumerate() { - temp_vec = builder.composite_insert( - vector_type.0, - None, - src_id.0, - temp_vec, - [index as u32].iter().copied(), - )?; - } - builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; - } - } - Statement::VectorAccess(vector_access) => todo!(), - } - } - Ok(()) -} - -fn emit_function_linkage<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - f: &Function, - fn_name: SpirvWord, -) -> Result<(), TranslateError> { - if f.linkage == ast::LinkingDirective::NONE { - return Ok(()); - }; - let linking_name = match f.func_decl.borrow().name { - // According to SPIR-V rules linkage attributes are invalid on kernels - ast::MethodName::Kernel(..) => return Ok(()), - ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( - || match id_defs.reverse_variables.get(&fn_id) { - Some(fn_name) => Ok(fn_name), - None => Err(error_unknown_symbol()), - }, - Result::Ok, - )?, - }; - emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); - Ok(()) -} - -fn get_function_type( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - spirv_input: impl Iterator, - spirv_output: &[ast::Variable], -) -> (SpirvWord, SpirvWord) { - map.get_or_add_fn( - builder, - spirv_input, - spirv_output - .iter() - .map(|var| SpirvType::new(var.v_type.clone())), - ) -} - -fn emit_linking_decoration<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - name_override: Option<&str>, - name: SpirvWord, - linking: ast::LinkingDirective, -) { - if linking == ast::LinkingDirective::NONE { - return; - } - if linking.contains(ast::LinkingDirective::VISIBLE) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name.0, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Export), - ] - .iter() - .cloned(), - ); - } else if linking.contains(ast::LinkingDirective::EXTERN) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name.0, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Import), - ] - .iter() - .cloned(), - ); - } - // TODO: handle LinkingDirective::WEAK -} - -fn effective_input_arguments<'a>( - this: &'a ast::MethodDeclaration<'a, SpirvWord>, -) -> impl Iterator + 'a { - let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); - this.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) -} - -fn emit_implicit_conversion( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - cv: &ImplicitConversion, -) -> Result<(), TranslateError> { - let from_parts = to_parts(&cv.from_type); - let to_parts = to_parts(&cv.to_type); - match (from_parts.kind, to_parts.kind, &cv.kind) { - (_, _, &ConversionKind::BitToPtr) => { - let dst_type = map.get_or_add( - builder, - SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), - ); - builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { - if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - if from_parts.scalar_kind != ast::ScalarKind::Float - && to_parts.scalar_kind != ast::ScalarKind::Float - { - // It is noop, but another instruction expects result of this conversion - builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } else { - builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - } else { - // This block is safe because it's illegal to implictly convert between floating point values - let same_width_bit_type = map.get_or_add( - builder, - SpirvType::new(type_from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - })), - ); - let same_width_bit_value = - builder.bitcast(same_width_bit_type.0, None, cv.src.0)?; - let wide_bit_type = type_from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..to_parts - }); - let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); - if to_parts.scalar_kind == ast::ScalarKind::Unsigned - || to_parts.scalar_kind == ast::ScalarKind::Bit - { - builder.u_convert( - wide_bit_type_spirv.0, - Some(cv.dst.0), - same_width_bit_value, - )?; - } else { - let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed - && to_parts.scalar_kind == ast::ScalarKind::Signed - { - dr::Builder::s_convert - } else { - dr::Builder::u_convert - }; - let wide_bit_value = - conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?; - emit_implicit_conversion( - builder, - map, - &ImplicitConversion { - src: SpirvWord(wide_bit_value), - dst: cv.dst, - from_type: wide_bit_type, - from_space: cv.from_space, - to_type: cv.to_type.clone(), - to_space: cv.to_space, - kind: ConversionKind::Default, - }, - )?; - } - } - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; - } - (_, _, &ConversionKind::PtrToPtr) => { - let result_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.to_space), - ), - ); - if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.from_space), - ), - ); - builder.bitcast(temp_type.0, None, cv.src.0)? - } else { - cv.src.0 - }; - builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?; - } else if cv.from_space == ast::StateSpace::Generic - && cv.to_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.from_space), - ), - ); - builder.bitcast(temp_type.0, None, cv.src.0)? - } else { - cv.src.0 - }; - builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?; - } else { - builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - } - (_, _, &ConversionKind::AddressOf) => { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - _ => unreachable!(), - } - Ok(()) -} - -fn vec_repr(t: T) -> Vec { - let mut result = vec![0; mem::size_of::()]; - unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; - result -} - -fn emit_abs( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - d: &ast::TypeFtz, - arg: &ast::AbsArgs, -) -> Result<(), dr::Error> { - let scalar_t = ast::ScalarType::from(d.type_); - let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { - spirv::CLOp::s_abs - } else { - spirv::CLOp::fabs - }; - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - cl_abs as spirv::Word, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - Ok(()) -} - -fn emit_mul_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MulArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(type_)); - match control { - ast::MulIntControl::Low => { - builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - } - ast::MulIntControl::High => { - let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { - spirv::CLOp::s_mul_hi - } else { - spirv::CLOp::u_mul_hi - }; - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - opencl_inst as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => { - let instr_width = type_.size_of(); - let instr_kind = type_.kind(); - let dst_type = scalar_from_parts(instr_width * 2, instr_kind); - let dst_type_id = map.get_or_add_scalar(builder, dst_type); - let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { - let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; - let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; - (src1, src2) - } else { - let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; - let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; - (src1, src2) - }; - builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; - builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); - } - } - Ok(()) -} - -fn emit_mul_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - ctr: &ast::ArithFloat, - arg: &ast::MulArgs, -) -> Result<(), dr::Error> { - if ctr.saturate { - todo!() - } - let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); - builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, ctr.rounding); - Ok(()) -} - -fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { - match kind { - ast::ScalarKind::Float => match width { - 2 => ast::ScalarType::F16, - 4 => ast::ScalarType::F32, - 8 => ast::ScalarType::F64, - _ => unreachable!(), - }, - ast::ScalarKind::Bit => match width { - 1 => ast::ScalarType::B8, - 2 => ast::ScalarType::B16, - 4 => ast::ScalarType::B32, - 8 => ast::ScalarType::B64, - _ => unreachable!(), - }, - ast::ScalarKind::Signed => match width { - 1 => ast::ScalarType::S8, - 2 => ast::ScalarType::S16, - 4 => ast::ScalarType::S32, - 8 => ast::ScalarType::S64, - _ => unreachable!(), - }, - ast::ScalarKind::Unsigned => match width { - 1 => ast::ScalarType::U8, - 2 => ast::ScalarType::U16, - 4 => ast::ScalarType::U32, - 8 => ast::ScalarType::U64, - _ => unreachable!(), - }, - ast::ScalarKind::Pred => ast::ScalarType::Pred, - } -} - -fn emit_rounding_decoration( - builder: &mut dr::Builder, - dst: SpirvWord, - rounding: Option, -) { - if let Some(rounding) = rounding { - builder.decorate( - dst.0, - spirv::Decoration::FPRoundingMode, - [rounding_to_spirv(rounding)].iter().cloned(), - ); - } -} - -fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { - let mode = match this { - ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, - ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, - ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, - ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, - }; - rspirv::dr::Operand::FPRoundingMode(mode) -} - -fn emit_add_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::AddArgs, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); - builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - Ok(()) -} - -fn emit_add_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::AddArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); - builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_setp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - setp: &ast::SetpData, - arg: &ast::SetpArgs, -) -> Result<(), dr::Error> { - let result_type = map - .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) - .0; - let result_id = Some(arg.dst1.0); - let operand_1 = arg.src1.0; - let operand_2 = arg.src2.0; - match setp.cmp_op { - ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { - builder.i_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { - builder.f_ord_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { - builder.i_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { - builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { - builder.u_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { - builder.s_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { - builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { - builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { - builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { - builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { - builder.u_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { - builder.s_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { - builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { - builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { - builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { - builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { - builder.f_unord_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { - builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { - builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { - builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { - builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { - builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - builder.logical_or(result_type, result_id, temp1, temp2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; - logical_not(builder, result_type, result_id, any_nan) - } - _ => todo!(), - }?; - Ok(()) -} - -// HACK ALERT -// Temporary workaround until IGC gets its shit together -// Currently IGC carries two copies of SPIRV-LLVM translator -// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. -// Obviously, old and buggy one is used for compiling L0 SPIRV -// https://github.com/intel/intel-graphics-compiler/issues/148 -fn logical_not( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - operand: spirv::Word, -) -> Result { - let const_true = builder.constant_true(result_type, None); - let const_false = builder.constant_false(result_type, None); - builder.select(result_type, result_id, operand, const_false, const_true) -} - -// HACK ALERT -// For some reason IGC fails linking if the value and shift size are of different type -fn insert_shift_hack( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - offset_var: spirv::Word, - size_of: usize, -) -> Result { - let result_type = match size_of { - 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), - 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), - 4 => return Ok(offset_var), - _ => return Err(error_unreachable()), - }; - Ok(builder.u_convert(result_type.0, None, offset_var)?) -} - -fn emit_cvt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - dets: &ast::CvtDetails, - arg: &ast::CvtArgs, -) -> Result<(), TranslateError> { - match dets.mode { - ptx_parser::CvtMode::SignExtend => { - let cv = ImplicitConversion { - src: arg.src, - dst: arg.dst, - from_type: dets.from.into(), - from_space: ast::StateSpace::Reg, - to_type: dets.to.into(), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::SignExtend, - }; - emit_implicit_conversion(builder, map, &cv)?; - } - ptx_parser::CvtMode::ZeroExtend - | ptx_parser::CvtMode::Truncate - | ptx_parser::CvtMode::Bitcast => { - let cv = ImplicitConversion { - src: arg.src, - dst: arg.dst, - from_type: dets.from.into(), - from_space: ast::StateSpace::Reg, - to_type: dets.to.into(), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::Default, - }; - emit_implicit_conversion(builder, map, &cv)?; - } - ptx_parser::CvtMode::SaturateUnsignedToSigned => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::SaturateSignedToUnsigned => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::FPExtend { flush_to_zero } => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::FPTruncate { - rounding, - flush_to_zero, - } => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPRound { - integer_rounding, - flush_to_zero, - } => { - if flush_to_zero == Some(true) { - todo!() - } - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - match integer_rounding { - Some(ast::RoundingMode::NearestEven) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::rint as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::Zero) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::trunc as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::NegativeInf) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::floor as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::PositiveInf) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::ceil as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - None => { - builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - } - } - ptx_parser::CvtMode::SignedFromFP { - rounding, - flush_to_zero, - } => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::UnsignedFromFP { - rounding, - flush_to_zero, - } => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPFromSigned(rounding) => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPFromUnsigned(rounding) => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - } - Ok(()) -} - -fn emit_mad_uint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_))) - .0; - match control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; - builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::u_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_mad_sint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0; - match control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; - builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::s_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_mad_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::mad as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_fma_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::FmaArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::fma as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_sub_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::SubArgs, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))) - .0; - builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - Ok(()) -} - -fn emit_sub_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::SubArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_min( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::MinArgs, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, - }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - cl_op as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_max( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::MaxArgs, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, - }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - cl_op as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_rcp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::RcpData, - arg: &ast::RcpArgs, -) -> Result<(), TranslateError> { - let is_f64 = desc.type_ == ast::ScalarType::F64; - let (instr_type, constant) = if is_f64 { - (ast::ScalarType::F64, vec_repr(1.0f64)) - } else { - (ast::ScalarType::F32, vec_repr(1.0f32)) - }; - let result_type = map.get_or_add_scalar(builder, instr_type); - let rounding = match desc.kind { - ptx_parser::RcpKind::Approx => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::native_recip as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - return Ok(()); - } - ptx_parser::RcpKind::Compliant(rounding) => rounding, - }; - let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; - builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - builder.decorate( - arg.dst.0, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - Ok(()) -} - -fn emit_atom( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &ast::AtomDetails, - arg: &ast::AtomArgs, -) -> Result<(), TranslateError> { - let spirv_op = match details.op { - ptx_parser::AtomicOp::And => dr::Builder::atomic_and, - ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, - ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, - ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, - ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, - ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => { - return Err(error_unreachable()) - } - ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, - ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, - ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, - ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, - ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, - ptx_parser::AtomicOp::FloatMin => todo!(), - ptx_parser::AtomicOp::FloatMax => todo!(), - }; - let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone())); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope_to_spirv(details.scope) as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics_to_spirv(details.semantics).bits()), - )?; - spirv_op( - builder, - result_type.0, - Some(arg.dst.0), - arg.src1.0, - memory_const.0, - semantics_const.0, - arg.src2.0, - )?; - Ok(()) -} - -fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope { - match this { - ast::MemScope::Cta => spirv::Scope::Workgroup, - ast::MemScope::Gpu => spirv::Scope::Device, - ast::MemScope::Sys => spirv::Scope::CrossDevice, - ptx_parser::MemScope::Cluster => todo!(), - } -} - -fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { - match this { - ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, - ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, - ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, - ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, - } -} - -fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { - match kind { - ast::DivFloatKind::Approx => { - builder.decorate( - dst.0, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - } - ast::DivFloatKind::Rounding(rnd) => { - emit_rounding_decoration(builder, dst, Some(rnd)); - } - ast::DivFloatKind::ApproxFull => {} - } -} - -fn emit_sqrt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - details: &ast::RcpData, - a: &ast::SqrtArgs, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add_scalar(builder, details.type_.into()); - let (ocl_op, rounding) = match details.kind { - ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), - ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)), - }; - builder.ext_inst( - result_type.0, - Some(a.dst.0), - opencl, - ocl_op as spirv::Word, - [dr::Operand::IdRef(a.src.0)].iter().cloned(), - )?; - emit_rounding_decoration(builder, a.dst, rounding); - Ok(()) -} - -// TODO: check what kind of assembly do we emit -fn emit_logical_xor_spirv( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - op1: spirv::Word, - op2: spirv::Word, -) -> Result { - let temp_or = builder.logical_or(result_type, None, op1, op2)?; - let temp_and = builder.logical_and(result_type, None, op1, op2)?; - let temp_neg = logical_not(builder, result_type, None, temp_and)?; - builder.logical_and(result_type, result_id, temp_or, temp_neg) -} - -fn emit_load_var( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &LoadVarDetails, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); - match details.member_index { - Some((index, Some(width))) => { - let vector_type = match details.typ { - ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t), - _ => return Err(error_mismatched_type()), - }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); - let vector_temp = builder.load( - vector_type_spirv.0, - None, - details.arg.src.0, - None, - iter::empty(), - )?; - builder.composite_extract( - result_type.0, - Some(details.arg.dst.0), - vector_temp, - [index as u32].iter().copied(), - )?; - } - Some((index, None)) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - let src = builder.in_bounds_access_chain( - result_ptr_type.0, - None, - details.arg.src.0, - [index_spirv.0].iter().copied(), - )?; - builder.load( - result_type.0, - Some(details.arg.dst.0), - src, - None, - iter::empty(), - )?; - } - None => { - builder.load( - result_type.0, - Some(details.arg.dst.0), - details.arg.src.0, - None, - iter::empty(), - )?; - } - }; - Ok(()) -} - -fn to_parts(this: &ast::Type) -> TypeParts { - match this { - ast::Type::Scalar(scalar) => TypeParts { - kind: TypeKind::Scalar, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - ast::Type::Vector(components, scalar) => TypeParts { - kind: TypeKind::Vector, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*components as u32], - }, - ast::Type::Array(_, scalar, components) => TypeParts { - kind: TypeKind::Array, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - }, - ast::Type::Pointer(scalar, space) => TypeParts { - kind: TypeKind::Pointer, - state_space: *space, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - } -} - -fn type_from_parts(t: TypeParts) -> ast::Type { - match t.kind { - TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), - TypeKind::Vector => ast::Type::Vector( - t.components[0] as u8, - scalar_from_parts(t.width, t.scalar_kind), - ), - TypeKind::Array => ast::Type::Array( - None, - scalar_from_parts(t.width, t.scalar_kind), - t.components, - ), - TypeKind::Pointer => { - ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) - } - } -} - -#[derive(Eq, PartialEq, Clone)] -struct TypeParts { - kind: TypeKind, - scalar_kind: ast::ScalarKind, - width: u8, - state_space: ast::StateSpace, - components: Vec, -} - -#[derive(Eq, PartialEq, Copy, Clone)] -enum TypeKind { - Scalar, - Vector, - Array, - Pointer, -} diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs deleted file mode 100644 index e496c75..0000000 --- a/ptx/src/pass/expand_arguments.rs +++ /dev/null @@ -1,181 +0,0 @@ -use super::*; -use ptx_parser as ast; - -pub(super) fn run<'a, 'b>( - func: Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func { - match s { - Statement::Label(id) => result.push(Statement::Label(id)), - Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), - Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), - Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), - Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Constant(c) => result.push(Statement::Constant(c)), - Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), - s => { - let (new_statement, post_stmts) = { - let mut visitor = FlattenArguments::new(&mut result, id_def); - (s.visit_map(&mut visitor)?, visitor.post_stmts) - }; - result.push(new_statement); - result.extend(post_stmts); - } - } - } - Ok(result) -} - -struct FlattenArguments<'a, 'b> { - func: &'b mut Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, - post_stmts: Vec, -} - -impl<'a, 'b> FlattenArguments<'a, 'b> { - fn new( - func: &'b mut Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, - ) -> Self { - FlattenArguments { - func, - id_def, - post_stmts: Vec::new(), - } - } - - fn reg(&mut self, name: SpirvWord) -> Result { - Ok(name) - } - - fn reg_offset( - &mut self, - reg: SpirvWord, - offset: i32, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - _is_dst: bool, - ) -> Result { - let (type_, state_space) = if let Some((type_, state_space)) = type_space { - (type_, state_space) - } else { - return Err(TranslateError::UntypedSymbol); - }; - if state_space == ast::StateSpace::Reg { - let (reg_type, reg_space) = self.id_def.get_typed(reg)?; - if reg_space != ast::StateSpace::Reg { - return Err(error_mismatched_type()); - } - let reg_scalar_type = match reg_type { - ast::Type::Scalar(underlying_type) => underlying_type, - _ => return Err(error_mismatched_type()), - }; - let id_constant_stmt = self - .id_def - .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: reg_scalar_type, - value: ast::ImmediateValue::S64(offset as i64), - })); - let arith_details = match reg_scalar_type.kind() { - ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { - type_: reg_scalar_type, - saturate: false, - }), - ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { - ast::ArithDetails::Integer(ast::ArithInteger { - type_: reg_scalar_type, - saturate: false, - }) - } - _ => return Err(error_unreachable()), - }; - let id_add_result = self.id_def.register_intermediate(reg_type, state_space); - self.func - .push(Statement::Instruction(ast::Instruction::Add { - data: arith_details, - arguments: ast::AddArgs { - dst: id_add_result, - src1: reg, - src2: id_constant_stmt, - }, - })); - Ok(id_add_result) - } else { - let id_constant_stmt = self.id_def.register_intermediate( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self - .id_def - .register_intermediate(type_.clone(), state_space); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: type_.clone(), - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - Ok(dst) - } - } - - fn immediate( - &mut self, - value: ast::ImmediateValue, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - ) -> Result { - let (scalar_t, state_space) = - if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { - (*scalar, state_space) - } else { - return Err(TranslateError::UntypedSymbol); - }; - let id = self - .id_def - .register_intermediate(ast::Type::Scalar(scalar_t), state_space); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value, - })); - Ok(id) - } -} - -impl<'a, 'b> ast::VisitorMap for FlattenArguments<'a, 'b> { - fn visit( - &mut self, - args: TypedOperand, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - match args { - TypedOperand::Reg(r) => self.reg(r), - TypedOperand::Imm(x) => self.immediate(x, type_space), - TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(reg, offset, type_space, is_dst) - } - TypedOperand::VecMember(..) => Err(error_unreachable()), - } - } - - fn visit_ident( - &mut self, - name: ::Ident, - _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - _is_dst: bool, - _relaxed_type_check: bool, - ) -> Result<::Ident, TranslateError> { - self.reg(name) - } -} diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index 3dabf40..f2de786 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_member( &mut self, - vector_src: SpirvWord, + vector_ident: SpirvWord, member: u8, _type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, ) -> Result { - if is_dst { - return Err(error_mismatched_type()); - } - let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? { + let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? { (ast::Type::Vector(vector_width, scalar_t), space) => { (*vector_width, *scalar_t, *space) } @@ -206,35 +203,46 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { let temporary = self .resolver .register_unnamed(Some((scalar_type.into(), space))); - self.result.push(Statement::VectorAccess(VectorAccess { - scalar_type, - vector_width, - dst: temporary, - src: vector_src, - member: member, - })); + if is_dst { + self.post_stmts.push(Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst: vector_ident, + vector_src: vector_ident, + scalar_src: temporary, + member, + })); + } else { + self.result.push(Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: temporary, + vector_src: vector_ident, + member, + })); + } Ok(temporary) } fn vec_pack( &mut self, - vecs: Vec, + vector_elements: Vec, type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { - let (scalar_t, state_space) = match type_space { - Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space), + let (width, scalar_t, state_space) = match type_space { + Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space), _ => return Err(error_mismatched_type()), }; - let temp_vec = self + let temporary_vector = self .resolver - .register_unnamed(Some((scalar_t.into(), state_space))); + .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space))); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, - packed: temp_vec, - unpacked: vecs, + packed: temporary_vector, + unpacked: vector_elements, relaxed_type_check, }); if is_dst { @@ -242,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { } else { self.result.push(statement); } - Ok(temp_vec) + Ok(temporary_vector) } } @@ -273,7 +281,7 @@ impl<'a, 'b> ast::VisitorMap, SpirvWord, Translate fn visit_ident( &mut self, - name: ::Ident, + name: SpirvWord, _type_space: Option<(&ast::Type, ast::StateSpace)>, _is_dst: bool, _relaxed_type_check: bool, diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs deleted file mode 100644 index 2912366..0000000 --- a/ptx/src/pass/extract_globals.rs +++ /dev/null @@ -1,281 +0,0 @@ -use super::*; - -pub(super) fn run<'input, 'b>( - sorted_statements: Vec, - ptx_impl_imports: &mut HashMap, - id_def: &mut NumericIdResolver, -) -> Result<(Vec, Vec>), TranslateError> { - let mut local = Vec::with_capacity(sorted_statements.len()); - let mut global = Vec::new(); - for statement in sorted_statements { - match statement { - Statement::Variable( - var @ ast::Variable { - state_space: ast::StateSpace::Shared, - .. - }, - ) - | Statement::Variable( - var @ ast::Variable { - state_space: ast::StateSpace::Global, - .. - }, - ) => global.push(var), - Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Bfe { data, arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Bfi { data, arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Brev { data, arguments }) => { - let fn_name: String = - [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Brev { data, arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Activemask { arguments }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Activemask { arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom { - data: - data @ ast::AtomDetails { - op: ast::AtomicOp::IncrementWrap, - semantics, - scope, - space, - .. - }, - arguments, - }) => { - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - semantics_to_ptx_name(semantics), - "_", - scope_to_ptx_name(scope), - "_", - space_to_ptx_name(space), - "_inc", - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom { data, arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom { - data: - data @ ast::AtomDetails { - op: ast::AtomicOp::DecrementWrap, - semantics, - scope, - space, - .. - }, - arguments, - }) => { - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - semantics_to_ptx_name(semantics), - "_", - scope_to_ptx_name(scope), - "_", - space_to_ptx_name(space), - "_dec", - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom { data, arguments }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom { - data: - data @ ast::AtomDetails { - op: ast::AtomicOp::FloatAdd, - semantics, - scope, - space, - .. - }, - arguments, - }) => { - let scalar_type = match data.type_ { - ptx_parser::Type::Scalar(scalar) => scalar, - _ => return Err(error_unreachable()), - }; - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - semantics_to_ptx_name(semantics), - "_", - scope_to_ptx_name(scope), - "_", - space_to_ptx_name(space), - "_add_", - scalar_to_ptx_name(scalar_type), - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom { data, arguments }, - fn_name, - )?); - } - s => local.push(s), - } - } - Ok((local, global)) -} - -fn instruction_to_fn_call( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - inst: ast::Instruction, - fn_name: String, -) -> Result { - let mut arguments = Vec::new(); - ast::visit_map(inst, &mut |operand, - type_space: Option<( - &ast::Type, - ast::StateSpace, - )>, - is_dst, - _| { - let (typ, space) = match type_space { - Some((typ, space)) => (typ.clone(), space), - None => return Err(error_unreachable()), - }; - arguments.push((operand, is_dst, typ, space)); - Ok(SpirvWord(0)) - })?; - let return_arguments_count = arguments - .iter() - .position(|(desc, is_dst, _, _)| !is_dst) - .unwrap_or(arguments.len()); - let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); - let fn_id = register_external_fn_call( - id_defs, - ptx_impl_imports, - fn_name, - return_arguments - .iter() - .map(|(_, _, typ, state)| (typ, *state)), - input_arguments - .iter() - .map(|(_, _, typ, state)| (typ, *state)), - )?; - Ok(Statement::Instruction(ast::Instruction::Call { - data: ast::CallDetails { - uniform: false, - return_arguments: return_arguments - .iter() - .map(|(_, _, typ, state)| (typ.clone(), *state)) - .collect::>(), - input_arguments: input_arguments - .iter() - .map(|(_, _, typ, state)| (typ.clone(), *state)) - .collect::>(), - }, - arguments: ast::CallArgs { - return_arguments: return_arguments - .iter() - .map(|(name, _, _, _)| *name) - .collect::>(), - func: fn_id, - input_arguments: input_arguments - .iter() - .map(|(name, _, _, _)| *name) - .collect::>(), - }, - })) -} - -fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { - match this { - ast::ScalarType::B8 => "b8", - ast::ScalarType::B16 => "b16", - ast::ScalarType::B32 => "b32", - ast::ScalarType::B64 => "b64", - ast::ScalarType::B128 => "b128", - ast::ScalarType::U8 => "u8", - ast::ScalarType::U16 => "u16", - ast::ScalarType::U16x2 => "u16x2", - ast::ScalarType::U32 => "u32", - ast::ScalarType::U64 => "u64", - ast::ScalarType::S8 => "s8", - ast::ScalarType::S16 => "s16", - ast::ScalarType::S16x2 => "s16x2", - ast::ScalarType::S32 => "s32", - ast::ScalarType::S64 => "s64", - ast::ScalarType::F16 => "f16", - ast::ScalarType::F16x2 => "f16x2", - ast::ScalarType::F32 => "f32", - ast::ScalarType::F64 => "f64", - ast::ScalarType::BF16 => "bf16", - ast::ScalarType::BF16x2 => "bf16x2", - ast::ScalarType::Pred => "pred", - } -} - -fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str { - match this { - ast::AtomSemantics::Relaxed => "relaxed", - ast::AtomSemantics::Acquire => "acquire", - ast::AtomSemantics::Release => "release", - ast::AtomSemantics::AcqRel => "acq_rel", - } -} - -fn scope_to_ptx_name(this: ast::MemScope) -> &'static str { - match this { - ast::MemScope::Cta => "cta", - ast::MemScope::Gpu => "gpu", - ast::MemScope::Sys => "sys", - ast::MemScope::Cluster => "cluster", - } -} - -fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { - match this { - ast::StateSpace::Generic => "generic", - ast::StateSpace::Global => "global", - ast::StateSpace::Shared => "shared", - ast::StateSpace::Reg => "reg", - ast::StateSpace::Const => "const", - ast::StateSpace::Local => "local", - ast::StateSpace::Param => "param", - ast::StateSpace::SharedCluster => "shared_cluster", - ast::StateSpace::ParamEntry => "param_entry", - ast::StateSpace::SharedCta => "shared_cta", - ast::StateSpace::ParamFunc => "param_func", - } -} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs deleted file mode 100644 index c029016..0000000 --- a/ptx/src/pass/fix_special_registers.rs +++ /dev/null @@ -1,130 +0,0 @@ -use super::*; -use std::collections::HashMap; - -pub(super) fn run<'a, 'b, 'input>( - ptx_impl_imports: &'a mut HashMap>, - typed_statements: Vec, - numeric_id_defs: &'a mut NumericIdResolver<'b>, -) -> Result, TranslateError> { - let result = Vec::with_capacity(typed_statements.len()); - let mut sreg_sresolver = SpecialRegisterResolver { - ptx_impl_imports, - numeric_id_defs, - result, - }; - for statement in typed_statements { - let statement = statement.visit_map(&mut sreg_sresolver)?; - sreg_sresolver.result.push(statement); - } - Ok(sreg_sresolver.result) -} - -struct SpecialRegisterResolver<'a, 'b, 'input> { - ptx_impl_imports: &'a mut HashMap>, - numeric_id_defs: &'a mut NumericIdResolver<'b>, - result: Vec, -} - -impl<'a, 'b, 'input> ast::VisitorMap - for SpecialRegisterResolver<'a, 'b, 'input> -{ - fn visit( - &mut self, - operand: TypedOperand, - _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index)) - } - - fn visit_ident( - &mut self, - args: SpirvWord, - _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - self.replace_sreg(args, is_dst, None) - } -} - -impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { - fn replace_sreg( - &mut self, - name: SpirvWord, - is_dst: bool, - vector_index: Option, - ) -> Result { - if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { - if is_dst { - return Err(error_mismatched_type()); - } - let input_arguments = match (vector_index, sreg.get_function_input_type()) { - (Some(idx), Some(inp_type)) => { - if inp_type != ast::ScalarType::U8 { - return Err(TranslateError::Unreachable); - } - let constant = self.numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(inp_type), - ast::StateSpace::Reg, - ))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: constant, - typ: inp_type, - value: ast::ImmediateValue::U64(idx as u64), - })); - vec![( - TypedOperand::Reg(constant), - ast::Type::Scalar(inp_type), - ast::StateSpace::Reg, - )] - } - (None, None) => Vec::new(), - _ => return Err(error_mismatched_type()), - }; - let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); - let return_type = sreg.get_function_return_type(); - let fn_result = self.numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(return_type), - ast::StateSpace::Reg, - ))); - let return_arguments = vec![( - fn_result, - ast::Type::Scalar(return_type), - ast::StateSpace::Reg, - )]; - let fn_call = register_external_fn_call( - self.numeric_id_defs, - self.ptx_impl_imports, - ocl_fn_name.to_string(), - return_arguments.iter().map(|(_, typ, space)| (typ, *space)), - input_arguments.iter().map(|(_, typ, space)| (typ, *space)), - )?; - let data = ast::CallDetails { - uniform: false, - return_arguments: return_arguments - .iter() - .map(|(_, typ, space)| (typ.clone(), *space)) - .collect(), - input_arguments: input_arguments - .iter() - .map(|(_, typ, space)| (typ.clone(), *space)) - .collect(), - }; - let arguments = ast::CallArgs { - return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), - func: fn_call, - input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(), - }; - self.result - .push(Statement::Instruction(ast::Instruction::Call { - data, - arguments, - })); - Ok(fn_result) - } else { - Ok(name) - } - } -} diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 97f6356..8c3b794 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>( sreg_to_function, result: Vec::new(), }; - directives - .into_iter() - .map(|directive| run_directive(&mut visitor, directive)) - .collect::, _>>() + for directive in directives.into_iter() { + result.push(run_directive(&mut visitor, directive)?); + } + Ok(result) } fn run_directive<'a, 'input>( @@ -112,7 +112,7 @@ impl<'a, 'b, 'input> is_dst: bool, _relaxed_type_check: bool, ) -> Result { - self.replace_sreg(args, None, is_dst) + Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args)) } } @@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { name: SpirvWord, vector_index: Option, is_dst: bool, - ) -> Result { + ) -> Result, TranslateError> { if let Some(sreg) = self.special_registers.get(name) { if is_dst { return Err(error_mismatched_type()); @@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { data, arguments, })); - Ok(fn_result) + Ok(Some(fn_result)) } else { - Ok(name) + Ok(None) } } } -pub fn map_operand( +pub fn map_operand( this: ast::ParsedOperand, - fn_: &mut impl FnMut(T, Option) -> Result, -) -> Result, Err> { + fn_: &mut impl FnMut(T, Option) -> Result, Err>, +) -> Result, Err> { Ok(match this { - ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?), + ast::ParsedOperand::Reg(ident) => { + ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident)) + } ast::ParsedOperand::RegOffset(ident, offset) => { - ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset) + ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset) } ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), - ast::ParsedOperand::VecMember(ident, member) => { - ast::ParsedOperand::Reg(fn_(ident, Some(member))?) - } + ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? { + Some(ident) => ast::ParsedOperand::Reg(ident), + None => ast::ParsedOperand::VecMember(ident, member), + }, ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( idents .into_iter() - .map(|ident| fn_(ident, None)) + .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident))) .collect::, _>>()?, ), }) diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 753172a..718c052 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -5,7 +5,7 @@ pub(super) fn run<'input>( ) -> Result, SpirvWord>>, TranslateError> { let mut result = Vec::with_capacity(directives.len()); for mut directive in directives.into_iter() { - run_directive(&mut result, &mut directive); + run_directive(&mut result, &mut directive)?; result.push(directive); } Ok(result) diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index ec6498c..60c4a14 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -1,7 +1,4 @@ use super::*; -use ptx_parser::VisitorMap; -use rustc_hash::FxHashSet; - // This pass: // * Turns all .local, .param and .reg in-body variables into .local variables // (if _not_ an input method argument) @@ -40,9 +37,6 @@ fn run_method<'a, 'input>( method: Function2<'input, ast::Instruction, SpirvWord>, ) -> Result, SpirvWord>, TranslateError> { let mut func_decl = method.func_decl; - for arg in func_decl.return_arguments.iter_mut() { - visitor.visit_variable(arg)?; - } let is_kernel = func_decl.name.is_kernel(); if is_kernel { for arg in func_decl.input_arguments.iter_mut() { @@ -52,17 +46,21 @@ fn run_method<'a, 'input>( let new_name = visitor .resolver .register_unnamed(Some((arg.v_type.clone(), new_space))); - visitor.input_argument(old_name, new_name, old_space); + visitor.input_argument(old_name, new_name, old_space)?; arg.name = new_name; arg.state_space = new_space; } }; + for arg in func_decl.return_arguments.iter_mut() { + visitor.visit_variable(arg)?; + } + let return_arguments = &func_decl.return_arguments[..]; let body = method .body .map(move |statements| { let mut result = Vec::with_capacity(statements.len()); for statement in statements { - run_statement(&mut visitor, &mut result, statement)?; + run_statement(&mut visitor, return_arguments, &mut result, statement)?; } Ok::<_, TranslateError>(result) }) @@ -79,10 +77,33 @@ fn run_method<'a, 'input>( fn run_statement<'a, 'input>( visitor: &mut InsertMemSSAVisitor<'a, 'input>, + return_arguments: &[ast::Variable], result: &mut Vec, statement: ExpandedStatement, ) -> Result<(), TranslateError> { match statement { + Statement::Instruction(ast::Instruction::Ret { data }) => { + let statement = if return_arguments.is_empty() { + Statement::Instruction(ast::Instruction::Ret { data }) + } else { + Statement::RetValue( + data, + return_arguments + .iter() + .map(|arg| { + if arg.state_space != ast::StateSpace::Local { + return Err(error_unreachable()); + } + Ok((arg.name, arg.v_type.clone())) + }) + .collect::, _>>()?, + ) + }; + let new_statement = statement.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(new_statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } Statement::Variable(mut var) => { visitor.visit_variable(&mut var)?; result.push(Statement::Variable(var)); @@ -154,7 +175,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { old_name: SpirvWord, new_name: SpirvWord, old_space: ast::StateSpace, - ) -> Result<(), TranslateError> { + ) -> Result { Ok(match old_space { ast::StateSpace::Reg => { self.variables.insert( @@ -164,6 +185,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { type_: type_.clone(), }, ); + true } ast::StateSpace::Param => { self.variables.insert( @@ -174,19 +196,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { name: new_name, }, ); + true } // Good as-is - ast::StateSpace::Local => {} - // Will be pulled into global scope later - ast::StateSpace::Generic + ast::StateSpace::Local + | ast::StateSpace::Generic | ast::StateSpace::SharedCluster | ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::SharedCta - | ast::StateSpace::Shared => {} - ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => { - return Err(error_unreachable()) - } + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => return Err(error_unreachable()), }) } @@ -239,17 +260,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { - if var.state_space != ast::StateSpace::Local { - let old_name = var.name; - let old_space = var.state_space; - let new_space = ast::StateSpace::Local; - let new_name = self - .resolver - .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(&var.v_type, old_name, new_name, old_space)?; - var.name = new_name; - var.state_space = new_space; - } + let old_space = match var.state_space { + space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, + // Do nothing + ptx_parser::StateSpace::Local => return Ok(()), + // Handled by another pass + ptx_parser::StateSpace::Generic + | ptx_parser::StateSpace::SharedCluster + | ptx_parser::StateSpace::ParamEntry + | ptx_parser::StateSpace::Global + | ptx_parser::StateSpace::SharedCta + | ptx_parser::StateSpace::Const + | ptx_parser::StateSpace::Shared + | ptx_parser::StateSpace::ParamFunc => return Ok(()), + }; + let old_name = var.name; + let new_space = ast::StateSpace::Local; + let new_name = self + .resolver + .register_unnamed(Some((var.v_type.clone(), new_space))); + self.variable(&var.v_type, old_name, new_name, old_space)?; + var.name = new_name; + var.state_space = new_space; Ok(()) } } @@ -260,9 +292,9 @@ impl<'a, 'input> ast::VisitorMap fn visit( &mut self, ident: SpirvWord, - type_space: Option<(&ast::Type, ast::StateSpace)>, + _type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, - relaxed_type_check: bool, + _relaxed_type_check: bool, ) -> Result { if let Some(remap) = self.variables.get(&ident) { match remap { diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs deleted file mode 100644 index c04fa09..0000000 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ /dev/null @@ -1,438 +0,0 @@ -use std::mem; - -use super::*; -use ptx_parser as ast; - -/* - There are several kinds of implicit conversions in PTX: - * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands - * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size - - ld.param: not documented, but for instruction `ld.param. x, [y]`, - semantics are to first zext/chop/bitcast `y` as needed and then do - documented special ld/st/cvt conversion rules for destination operands - - st.param [x] y (used as function return arguments) same rule as above applies - - generic/global ld: for instruction `ld x, [y]`, y must be of type - b64/u64/s64, which is bitcast to a pointer, dereferenced and then - documented special ld/st/cvt conversion rules are applied to dst - - generic/global st: for instruction `st [x], y`, x must be of type - b64/u64/s64, which is bitcast to a pointer -*/ -pub(super) fn run( - func: Vec, - id_def: &mut MutableNumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func.into_iter() { - match s { - Statement::Instruction(inst) => { - insert_implicit_conversions_impl( - &mut result, - id_def, - Statement::Instruction(inst), - )?; - } - Statement::PtrAccess(access) => { - insert_implicit_conversions_impl( - &mut result, - id_def, - Statement::PtrAccess(access), - )?; - } - Statement::RepackVector(repack) => { - insert_implicit_conversions_impl( - &mut result, - id_def, - Statement::RepackVector(repack), - )?; - } - Statement::VectorAccess(vector_access) => { - insert_implicit_conversions_impl( - &mut result, - id_def, - Statement::VectorAccess(vector_access), - )?; - } - s @ Statement::Conditional(_) - | s @ Statement::Conversion(_) - | s @ Statement::Label(_) - | s @ Statement::Constant(_) - | s @ Statement::Variable(_) - | s @ Statement::LoadVar(..) - | s @ Statement::StoreVar(..) - | s @ Statement::RetValue(..) - | s @ Statement::FunctionPointer(..) => result.push(s), - } - } - Ok(result) -} - -fn insert_implicit_conversions_impl( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - stmt: ExpandedStatement, -) -> Result<(), TranslateError> { - let mut post_conv = Vec::new(); - let statement = stmt.visit_map::( - &mut |operand, - type_state: Option<(&ast::Type, ast::StateSpace)>, - is_dst, - relaxed_type_check| { - let (instr_type, instruction_space) = match type_state { - None => return Ok(operand), - Some(t) => t, - }; - let (operand_type, operand_space) = id_def.get_typed(operand)?; - let conversion_fn = if relaxed_type_check { - if is_dst { - should_convert_relaxed_dst_wrapper - } else { - should_convert_relaxed_src_wrapper - } - } else { - default_implicit_conversion - }; - match conversion_fn( - (operand_space, &operand_type), - (instruction_space, instr_type), - )? { - Some(conv_kind) => { - let conv_output = if is_dst { &mut post_conv } else { &mut *func }; - let mut from_type = instr_type.clone(); - let mut from_space = instruction_space; - let mut to_type = operand_type; - let mut to_space = operand_space; - let mut src = - id_def.register_intermediate(instr_type.clone(), instruction_space); - let mut dst = operand; - let result = Ok::<_, TranslateError>(src); - if !is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from_type, &mut to_type); - mem::swap(&mut from_space, &mut to_space); - } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - from_space, - to_type, - to_space, - kind: conv_kind, - })); - result - } - None => Ok(operand), - } - }, - )?; - func.push(statement); - func.append(&mut post_conv); - Ok(()) -} - -pub(crate) fn default_implicit_conversion( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if instruction_space == ast::StateSpace::Reg { - if operand_space == ast::StateSpace::Reg { - if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = - (operand_type, instruction_type) - { - if scalar.kind() == ast::ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) - { - return Ok(Some(ConversionKind::Default)); - } - } - } else if is_addressable(operand_space) { - return Ok(Some(ConversionKind::AddressOf)); - } - } - if instruction_space != operand_space { - default_implicit_conversion_space( - (operand_space, operand_type), - (instruction_space, instruction_type), - ) - } else if instruction_type != operand_type { - default_implicit_conversion_type(instruction_space, operand_type, instruction_type) - } else { - Ok(None) - } -} - -fn is_addressable(this: ast::StateSpace) -> bool { - match this { - ast::StateSpace::Const - | ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Param | ast::StateSpace::Reg => false, - ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta - | ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc => todo!(), - } -} - -// Space is different -fn default_implicit_conversion_space( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) - || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) - { - Ok(Some(ConversionKind::PtrToPtr)) - } else if operand_space == ast::StateSpace::Reg { - match operand_type { - ast::Type::Pointer(operand_ptr_type, operand_ptr_space) - if *operand_ptr_space == instruction_space => - { - if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { - Ok(Some(ConversionKind::PtrToPtr)) - } else { - Ok(None) - } - } - // TODO: 32 bit - ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { - ast::StateSpace::Global - | ast::StateSpace::Generic - | ast::StateSpace::Const - | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), - _ => Err(error_mismatched_type()), - }, - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { - Ok(Some(ConversionKind::BitToPtr)) - } - _ => Err(error_mismatched_type()), - }, - _ => Err(error_mismatched_type()), - } - } else if instruction_space == ast::StateSpace::Reg { - match instruction_type { - ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) - if operand_space == *instruction_ptr_space => - { - if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { - Ok(Some(ConversionKind::PtrToPtr)) - } else { - Ok(None) - } - } - _ => Err(error_mismatched_type()), - } - } else { - Err(error_mismatched_type()) - } -} - -// Space is same, but type is different -fn default_implicit_conversion_type( - space: ast::StateSpace, - operand_type: &ast::Type, - instruction_type: &ast::Type, -) -> Result, TranslateError> { - if space == ast::StateSpace::Reg { - if should_bitcast(instruction_type, operand_type) { - Ok(Some(ConversionKind::Default)) - } else { - Err(TranslateError::MismatchedType) - } - } else { - Ok(Some(ConversionKind::PtrToPtr)) - } -} - -fn coerces_to_generic(this: ast::StateSpace) -> bool { - match this { - ast::StateSpace::Global - | ast::StateSpace::Const - | ast::StateSpace::Local - | ptx_parser::StateSpace::SharedCta - | ast::StateSpace::SharedCluster - | ast::StateSpace::Shared => true, - ast::StateSpace::Reg - | ast::StateSpace::Param - | ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::Generic => false, - } -} - -fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { - match (instr, operand) { - (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { - if inst.size_of() != operand.size_of() { - return false; - } - match inst.kind() { - ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, - ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, - ast::ScalarKind::Signed => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Unsigned - } - ast::ScalarKind::Unsigned => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Signed - } - ast::ScalarKind::Pred => false, - } - } - (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) - | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { - should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) - } - _ => false, - } -} - -pub(crate) fn should_convert_relaxed_dst_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if operand_space != instruction_space { - return Err(TranslateError::MismatchedType); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_dst(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands -fn should_convert_relaxed_dst( - dst_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if dst_type == instr_type { - return None; - } - match (dst_type, instr_type) { - (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= dst_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed => { - if dst_type.kind() != ast::ScalarKind::Float { - if instr_type.size_of() == dst_type.size_of() { - Some(ConversionKind::Default) - } else if instr_type.size_of() < dst_type.size_of() { - Some(ConversionKind::SignExtend) - } else { - None - } - } else { - None - } - } - ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) - | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { - should_convert_relaxed_dst( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} - -pub(crate) fn should_convert_relaxed_src_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if operand_space != instruction_space { - return Err(error_mismatched_type()); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_src(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(error_mismatched_type()), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands -fn should_convert_relaxed_src( - src_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if src_type == instr_type { - return None; - } - match (src_type, instr_type) { - (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= src_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) - | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { - should_convert_relaxed_src( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs deleted file mode 100644 index 150109b..0000000 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ /dev/null @@ -1,275 +0,0 @@ -use super::*; -use ptx_parser as ast; - -/* - How do we handle arguments: - - input .params in kernels - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - We do this for two reasons. One, common treatment for argument-declared - .param variables and .param variables inside function (we assume that - at SPIR-V level every .param is a pointer in Function storage class) - - input .params in functions - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %_ptr_Function_ulong - - input .regs - .reg .b64 in_arg - get turned into the same SPIR-V as kernel .params: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - - output .regs - .reg .b64 out_arg - get just a variable declaration: - %2 = OpVariable %%_ptr_Function_ulong Function - - output .params don't exist, they have been moved to input positions - by an earlier pass - Distinguishing betweem kernel .params and function .params is not the - cleanest solution. Alternatively, we could "deparamize" all kernel .param - arguments by turning them into .reg arguments like this: - .param .b64 arg -> .reg ptr<.b64,.param> arg - This has the massive downside that this transformation would have to run - very early and would muddy up already difficult code. It's simpler to just - have an if here -*/ -pub(super) fn run<'a, 'b>( - func: Vec, - id_def: &mut NumericIdResolver, - fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.input_arguments.iter_mut() { - insert_mem_ssa_argument( - id_def, - &mut result, - arg, - matches!(fn_decl.name, ast::MethodName::Kernel(_)), - ); - } - for arg in fn_decl.return_arguments.iter() { - insert_mem_ssa_argument_reg_return(&mut result, arg); - } - for s in func { - match s { - Statement::Instruction(inst) => match inst { - ast::Instruction::Ret { data } => { - // TODO: handle multiple output args - match &fn_decl.return_arguments[..] { - [return_reg] => { - let new_id = id_def.register_intermediate(Some(( - return_reg.v_type.clone(), - ast::StateSpace::Reg, - ))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::LdArgs { - dst: new_id, - src: return_reg.name, - }, - typ: return_reg.v_type.clone(), - member_index: None, - })); - result.push(Statement::RetValue(data, new_id)); - } - [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })), - _ => unimplemented!(), - } - } - inst => insert_mem_ssa_statement_default( - id_def, - &mut result, - Statement::Instruction(inst), - )?, - }, - Statement::Conditional(bra) => { - insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))? - } - Statement::Conversion(conv) => { - insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))? - } - Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default( - id_def, - &mut result, - Statement::PtrAccess(ptr_access), - )?, - Statement::RepackVector(repack) => insert_mem_ssa_statement_default( - id_def, - &mut result, - Statement::RepackVector(repack), - )?, - Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default( - id_def, - &mut result, - Statement::FunctionPointer(func_ptr), - )?, - s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => { - result.push(s) - } - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - -fn insert_mem_ssa_argument( - id_def: &mut NumericIdResolver, - func: &mut Vec, - arg: &mut ast::Variable, - is_kernel: bool, -) { - if !is_kernel && arg.state_space == ast::StateSpace::Param { - return; - } - let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); - func.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: ast::StateSpace::Reg, - name: arg.name, - array_init: Vec::new(), - })); - func.push(Statement::StoreVar(StoreVarDetails { - arg: ast::StArgs { - src1: arg.name, - src2: new_id, - }, - typ: arg.v_type.clone(), - member_index: None, - })); - arg.name = new_id; -} - -fn insert_mem_ssa_argument_reg_return( - func: &mut Vec, - arg: &ast::Variable, -) { - func.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: arg.array_init.clone(), - })); -} - -fn insert_mem_ssa_statement_default<'a, 'input>( - id_def: &'a mut NumericIdResolver<'input>, - func: &'a mut Vec, - stmt: TypedStatement, -) -> Result<(), TranslateError> { - let mut visitor = InsertMemSSAVisitor { - id_def, - func, - post_statements: Vec::new(), - }; - let new_stmt = stmt.visit_map(&mut visitor)?; - visitor.func.push(new_stmt); - visitor.func.extend(visitor.post_statements); - Ok(()) -} - -struct InsertMemSSAVisitor<'a, 'input> { - id_def: &'a mut NumericIdResolver<'input>, - func: &'a mut Vec, - post_statements: Vec, -} - -impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { - fn symbol( - &mut self, - symbol: SpirvWord, - member_index: Option, - expected: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - ) -> Result { - if expected.is_none() { - return Ok(symbol); - }; - let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; - if var_space != ast::StateSpace::Reg || !is_variable { - return Ok(symbol); - }; - let member_index = match member_index { - Some(idx) => { - let vector_width = match var_type { - ast::Type::Vector(width, scalar_t) => { - var_type = ast::Type::Scalar(scalar_t); - width - } - _ => return Err(error_mismatched_type()), - }; - Some(( - idx, - if self.id_def.special_registers.get(symbol).is_some() { - Some(vector_width) - } else { - None - }, - )) - } - None => None, - }; - let generated_id = self - .id_def - .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); - if !is_dst { - self.func.push(Statement::LoadVar(LoadVarDetails { - arg: ast::LdArgs { - dst: generated_id, - src: symbol, - }, - typ: var_type, - member_index, - })); - } else { - self.post_statements - .push(Statement::StoreVar(StoreVarDetails { - arg: ast::StArgs { - src1: symbol, - src2: generated_id, - }, - typ: var_type, - member_index: member_index.map(|(idx, _)| idx), - })); - } - Ok(generated_id) - } -} - -impl<'a, 'input> ast::VisitorMap - for InsertMemSSAVisitor<'a, 'input> -{ - fn visit( - &mut self, - operand: TypedOperand, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - _relaxed_type_check: bool, - ) -> Result { - Ok(match operand { - TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?) - } - TypedOperand::RegOffset(reg, offset) => { - TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset) - } - op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => { - TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?) - } - }) - } - - fn visit_ident( - &mut self, - args: SpirvWord, - type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - self.symbol(args, None, type_space, is_dst) - } -} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0e233ed..ef131b4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,84 +1,43 @@ use ptx_parser as ast; -use rspirv::{binary::Assemble, dr}; +use quick_error::quick_error; use rustc_hash::FxHashMap; use std::hash::Hash; -use std::num::NonZeroU8; use std::{ borrow::Cow, - cell::RefCell, - collections::{hash_map, HashMap, HashSet}, + collections::{hash_map, HashMap}, ffi::CString, iter, - marker::PhantomData, - mem, - rc::Rc, }; use strum::IntoEnumIterator; use strum_macros::EnumIter; -mod convert_dynamic_shared_memory_usage; -mod convert_to_stateful_memory_access; -mod convert_to_typed; mod deparamize_functions; pub(crate) mod emit_llvm; -mod emit_spirv; -mod expand_arguments; mod expand_operands; -mod extract_globals; -mod fix_special_registers; mod fix_special_registers2; mod hoist_globals; mod insert_explicit_load_store; -mod insert_implicit_conversions; mod insert_implicit_conversions2; -mod insert_mem_ssa_statements; -mod normalize_identifiers; mod normalize_identifiers2; -mod normalize_labels; -mod normalize_predicates; mod normalize_predicates2; +mod replace_instructions_with_function_calls; mod resolve_function_pointers; -static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); -static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; -pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { - let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); - let mut ptx_impl_imports = HashMap::new(); - let directives = ast - .directives - .into_iter() - .filter_map(|directive| { - translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose() - }) - .collect::, _>>()?; - let directives = hoist_function_globals(directives); - let must_link_ptx_impl = ptx_impl_imports.len() > 0; - let mut directives = ptx_impl_imports - .into_iter() - .map(|(_, v)| v) - .chain(directives.into_iter()) - .collect::>(); - let mut builder = dr::Builder::new(); - builder.reserve_ids(id_defs.current_id().0); - let call_map = MethodsCallMap::new(&directives); - let mut directives = - convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || { - SpirvWord(builder.id()) - })?; - normalize_variable_decls(&mut directives); - let denorm_information = compute_denorm_information(&directives); - todo!() - /* - let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?; - Ok(Module { - llvm_ir, - kernel_info: HashMap::new(), - }) */ +quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol {} + UntypedSymbol {} + MismatchedType {} + Unreachable {} + Todo {} + } } -pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result { +pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; @@ -86,11 +45,11 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result, SpirvWord>> = - expand_operands::run(&mut flat_resolver, directives)?; + let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?; let directives = hoist_globals::run(directives)?; let llvm_ir = emit_llvm::run(flat_resolver, directives)?; Ok(Module { @@ -99,254 +58,15 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result( - id_defs: &'a mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &'a mut HashMap>, - d: ast::Directive<'input, ast::ParsedOperand<&'input str>>, -) -> Result>, TranslateError> { - Ok(match d { - ast::Directive::Variable(linking, var) => Some(Directive::Variable( - linking, - ast::Variable { - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), - array_init: var.array_init, - }, - )), - ast::Directive::Method(linkage, f) => { - translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method) - } - }) -} - -type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement>>; - -fn translate_function<'input, 'a>( - id_defs: &'a mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &'a mut HashMap>, - linkage: ast::LinkingDirective, - f: ParsedFunction<'input>, -) -> Result>, TranslateError> { - let import_as = match &f.func_directive { - ast::MethodDeclaration { - name: ast::MethodName::Func(func_name), - .. - } if *func_name == "__assertfail" || *func_name == "vprintf" => { - Some([ZLUDA_PTX_PREFIX, func_name].concat()) - } - _ => None, - }; - let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; - let mut func = to_ssa( - ptx_impl_imports, - str_resolver, - fn_resolver, - fn_decl, - f.body, - f.tuning, - linkage, - )?; - func.import_as = import_as; - if func.import_as.is_some() { - ptx_impl_imports.insert( - func.import_as.as_ref().unwrap().clone(), - Directive::Method(func), - ); - Ok(None) - } else { - Ok(Some(func)) - } -} - -fn to_ssa<'input, 'b>( - ptx_impl_imports: &'b mut HashMap>, - mut id_defs: FnStringIdResolver<'input, 'b>, - fn_defs: GlobalFnDeclResolver<'input, 'b>, - func_decl: Rc>>, - f_body: Option>>>, - tuning: Vec, - linkage: ast::LinkingDirective, -) -> Result, TranslateError> { - //deparamize_function_decl(&func_decl)?; - let f_body = match f_body { - Some(vec) => vec, - None => { - return Ok(Function { - func_decl: func_decl, - body: None, - globals: Vec::new(), - import_as: None, - tuning, - linkage, - }) - } - }; - let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?; - let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; - let typed_statements = - convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; - let (func_decl, typed_statements) = - convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = insert_mem_ssa_statements::run( - typed_statements, - &mut numeric_id_defs, - &mut (*func_decl).borrow_mut(), - )?; - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; - let expanded_statements = - insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?; - let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs); - let (f_body, globals) = - extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; - Ok(Function { - func_decl: func_decl, - globals: globals, - body: Some(f_body), - import_as: None, - tuning, - linkage, - }) -} - pub struct Module { pub llvm_ir: emit_llvm::MemoryBuffer, pub kernel_info: HashMap, } -struct GlobalStringIdResolver<'input> { - current_id: SpirvWord, - variables: HashMap, SpirvWord>, - pub(crate) reverse_variables: HashMap, - variables_type_check: HashMap>, - special_registers: SpecialRegistersMap, - fns: HashMap>, -} - -impl<'input> GlobalStringIdResolver<'input> { - fn new(start_id: SpirvWord) -> Self { - Self { - current_id: start_id, - variables: HashMap::new(), - reverse_variables: HashMap::new(), - variables_type_check: HashMap::new(), - special_registers: SpecialRegistersMap::new(), - fns: HashMap::new(), - } - } - - fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord { - self.get_or_add_impl(id, None) - } - - fn get_or_add_def_typed( - &mut self, - id: &'input str, - typ: ast::Type, - state_space: ast::StateSpace, - is_variable: bool, - ) -> SpirvWord { - self.get_or_add_impl(id, Some((typ, state_space, is_variable))) - } - - fn get_or_add_impl( - &mut self, - id: &'input str, - typ: Option<(ast::Type, ast::StateSpace, bool)>, - ) -> SpirvWord { - let id = match self.variables.entry(Cow::Borrowed(id)) { - hash_map::Entry::Occupied(e) => *(e.get()), - hash_map::Entry::Vacant(e) => { - let numeric_id = self.current_id; - e.insert(numeric_id); - self.reverse_variables.insert(numeric_id, id); - self.current_id.0 += 1; - numeric_id - } - }; - self.variables_type_check.insert(id, typ); - id - } - - fn get_id(&self, id: &str) -> Result { - self.variables - .get(id) - .copied() - .ok_or_else(error_unknown_symbol) +impl Module { + pub fn linked_bitcode(&self) -> &[u8] { + ZLUDA_PTX_IMPL } - - fn current_id(&self) -> SpirvWord { - self.current_id - } - - fn start_fn<'b>( - &'b mut self, - header: &'b ast::MethodDeclaration<'input, &'input str>, - ) -> Result< - ( - FnStringIdResolver<'input, 'b>, - GlobalFnDeclResolver<'input, 'b>, - Rc>>, - ), - TranslateError, - > { - // In case a function decl was inserted earlier we want to use its id - let name_id = self.get_or_add_def(header.name()); - let mut fn_resolver = FnStringIdResolver { - current_id: &mut self.current_id, - global_variables: &self.variables, - global_type_check: &self.variables_type_check, - special_registers: &mut self.special_registers, - variables: vec![HashMap::new(); 1], - type_check: HashMap::new(), - }; - let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); - let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); - let name = match header.name { - ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), - ast::MethodName::Func(_) => ast::MethodName::Func(name_id), - }; - let fn_decl = ast::MethodDeclaration { - return_arguments, - name, - input_arguments, - shared_mem: None, - }; - let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) { - let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); - let new_fn_decl = resolver.func_decl.clone(); - self.fns.insert(name_id, resolver); - new_fn_decl - } else { - Rc::new(RefCell::new(fn_decl)) - }; - Ok(( - fn_resolver, - GlobalFnDeclResolver { fns: &self.fns }, - new_fn_decl, - )) - } -} - -fn rename_fn_params<'a, 'b>( - fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: &'b [ast::Variable<&'a str>], -) -> Vec> { - args.iter() - .map(|a| ast::Variable { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), - v_type: a.v_type.clone(), - state_space: a.state_space, - align: a.align, - array_init: a.array_init.clone(), - }) - .collect() } pub struct KernelInfo { @@ -365,18 +85,6 @@ enum PtxSpecialRegister { } impl PtxSpecialRegister { - fn try_parse(s: &str) -> Option { - match s { - "%tid" => Some(Self::Tid), - "%ntid" => Some(Self::Ntid), - "%ctaid" => Some(Self::Ctaid), - "%nctaid" => Some(Self::Nctaid), - "%clock" => Some(Self::Clock), - "%lanemask_lt" => Some(Self::LanemaskLt), - _ => None, - } - } - fn as_str(self) -> &'static str { match self { Self::Tid => "%tid", @@ -431,216 +139,24 @@ impl PtxSpecialRegister { } } -struct SpecialRegistersMap { - reg_to_id: HashMap, - id_to_reg: HashMap, -} - -impl SpecialRegistersMap { - fn new() -> Self { - SpecialRegistersMap { - reg_to_id: HashMap::new(), - id_to_reg: HashMap::new(), - } - } - - fn get(&self, id: SpirvWord) -> Option { - self.id_to_reg.get(&id).copied() - } - - fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { - match self.reg_to_id.entry(reg) { - hash_map::Entry::Occupied(e) => *e.get(), - hash_map::Entry::Vacant(e) => { - let numeric_id = SpirvWord(current_id.0); - current_id.0 += 1; - e.insert(numeric_id); - self.id_to_reg.insert(numeric_id, reg); - numeric_id - } - } - } -} - -struct FnStringIdResolver<'input, 'b> { - current_id: &'b mut SpirvWord, - global_variables: &'b HashMap, SpirvWord>, - global_type_check: &'b HashMap>, - special_registers: &'b mut SpecialRegistersMap, - variables: Vec, SpirvWord>>, - type_check: HashMap>, -} - -impl<'a, 'b> FnStringIdResolver<'a, 'b> { - fn finish(self) -> NumericIdResolver<'b> { - NumericIdResolver { - current_id: self.current_id, - global_type_check: self.global_type_check, - type_check: self.type_check, - special_registers: self.special_registers, - } - } - - fn start_block(&mut self) { - self.variables.push(HashMap::new()) - } - - fn end_block(&mut self) { - self.variables.pop(); - } - - fn get_id(&mut self, id: &str) -> Result { - for scope in self.variables.iter().rev() { - match scope.get(id) { - Some(id) => return Ok(*id), - None => continue, - } - } - match self.global_variables.get(id) { - Some(id) => Ok(*id), - None => { - let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?; - Ok(self.special_registers.get_or_add(self.current_id, sreg)) - } - } - } - - fn add_def( - &mut self, - id: &'a str, - typ: Option<(ast::Type, ast::StateSpace)>, - is_variable: bool, - ) -> SpirvWord { - let numeric_id = *self.current_id; - self.variables - .last_mut() - .unwrap() - .insert(Cow::Borrowed(id), numeric_id); - self.type_check.insert( - numeric_id, - typ.map(|(typ, space)| (typ, space, is_variable)), - ); - self.current_id.0 += 1; - numeric_id - } - - #[must_use] - fn add_defs( - &mut self, - base_id: &'a str, - count: u32, - typ: ast::Type, - state_space: ast::StateSpace, - is_variable: bool, - ) -> impl Iterator { - let numeric_id = *self.current_id; - for i in 0..count { - self.variables.last_mut().unwrap().insert( - Cow::Owned(format!("{}{}", base_id, i)), - SpirvWord(numeric_id.0 + i), - ); - self.type_check.insert( - SpirvWord(numeric_id.0 + i), - Some((typ.clone(), state_space, is_variable)), - ); - } - self.current_id.0 += count; - (0..count) - .into_iter() - .map(move |i| SpirvWord(i + numeric_id.0)) - } -} - -struct NumericIdResolver<'b> { - current_id: &'b mut SpirvWord, - global_type_check: &'b HashMap>, - type_check: HashMap>, - special_registers: &'b mut SpecialRegistersMap, -} - -impl<'b> NumericIdResolver<'b> { - fn finish(self) -> MutableNumericIdResolver<'b> { - MutableNumericIdResolver { base: self } - } - - fn get_typed( - &self, - id: SpirvWord, - ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { - match self.type_check.get(&id) { - Some(Some(x)) => Ok(x.clone()), - Some(None) => Err(TranslateError::UntypedSymbol), - None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)), - None => match self.global_type_check.get(&id) { - Some(Some(result)) => Ok(result.clone()), - Some(None) | None => Err(TranslateError::UntypedSymbol), - }, - }, - } - } - - // This is for identifiers which will be emitted later as OpVariable - // They are candidates for insertion of LoadVar/StoreVar - fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { - let new_id = *self.current_id; - self.type_check - .insert(new_id, Some((typ, state_space, true))); - self.current_id.0 += 1; - new_id - } - - fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { - let new_id = *self.current_id; - self.type_check - .insert(new_id, typ.map(|(t, space)| (t, space, false))); - self.current_id.0 += 1; - new_id - } -} - -struct MutableNumericIdResolver<'b> { - base: NumericIdResolver<'b>, -} - -impl<'b> MutableNumericIdResolver<'b> { - fn unmut(self) -> NumericIdResolver<'b> { - self.base - } - - fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> { - self.base.get_typed(id).map(|(t, space, _)| (t, space)) - } - - fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { - self.base.register_intermediate(Some((typ, state_space))) - } +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() } -quick_error! { - #[derive(Debug)] - pub enum TranslateError { - UnknownSymbol {} - UntypedSymbol {} - MismatchedType {} - Spirv(err: rspirv::dr::Error) { - from() - display("{}", err) - cause(err) - } - Unreachable {} - Todo {} - } +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable } #[cfg(debug_assertions)] -fn error_unreachable() -> TranslateError { +fn error_todo() -> TranslateError { unreachable!() } #[cfg(not(debug_assertions))] -fn error_unreachable() -> TranslateError { - TranslateError::Unreachable +fn error_todo() -> TranslateError { + TranslateError::Todo } #[cfg(debug_assertions)] @@ -663,112 +179,20 @@ fn error_mismatched_type() -> TranslateError { TranslateError::MismatchedType } -pub struct GlobalFnDeclResolver<'input, 'a> { - fns: &'a HashMap>, -} - -impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> { - self.fns.get(&id).ok_or_else(error_unknown_symbol) - } -} - -struct FnSigMapper<'input> { - // true - stays as return argument - // false - is moved to input argument - return_param_args: Vec, - func_decl: Rc>>, -} - -impl<'input> FnSigMapper<'input> { - fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self { - let return_param_args = method - .return_arguments - .iter() - .map(|a| a.state_space != ast::StateSpace::Param) - .collect::>(); - let mut new_return_arguments = Vec::new(); - for arg in method.return_arguments.into_iter() { - if arg.state_space == ast::StateSpace::Param { - method.input_arguments.push(arg); - } else { - new_return_arguments.push(arg); - } - } - method.return_arguments = new_return_arguments; - FnSigMapper { - return_param_args, - func_decl: Rc::new(RefCell::new(method)), - } - } - - fn resolve_in_spirv_repr( - &self, - data: ast::CallDetails, - arguments: ast::CallArgs>, - ) -> Result>, TranslateError> { - let func_decl = (*self.func_decl).borrow(); - let mut data_return = Vec::new(); - let mut arguments_return = Vec::new(); - let mut data_input = data.input_arguments; - let mut arguments_input = arguments.input_arguments; - let mut func_decl_return_iter = func_decl.return_arguments.iter(); - let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter(); - for (idx, id) in arguments.return_arguments.iter().enumerate() { - let stays_as_return = match self.return_param_args.get(idx) { - Some(x) => *x, - None => return Err(TranslateError::MismatchedType), - }; - if stays_as_return { - if let Some(var) = func_decl_return_iter.next() { - data_return.push((var.v_type.clone(), var.state_space)); - arguments_return.push(*id); - } else { - return Err(TranslateError::MismatchedType); - } - } else { - if let Some(var) = func_decl_input_iter.next() { - data_input.push((var.v_type.clone(), var.state_space)); - arguments_input.push(ast::ParsedOperand::Reg(*id)); - } else { - return Err(TranslateError::MismatchedType); - } - } - } - if arguments_return.len() != func_decl.return_arguments.len() - || arguments_input.len() != func_decl.input_arguments.len() - { - return Err(TranslateError::MismatchedType); - } - let data = ast::CallDetails { - uniform: data.uniform, - return_arguments: data_return, - input_arguments: data_input, - }; - let arguments = ast::CallArgs { - func: arguments.func, - return_arguments: arguments_return, - input_arguments: arguments_input, - }; - Ok(ast::Instruction::Call { data, arguments }) - } -} - enum Statement { Label(SpirvWord), Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), - LoadVar(LoadVarDetails), - StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), - RetValue(ast::RetData, SpirvWord), + RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>), PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), FunctionPointer(FunctionPointerDetails), - VectorAccess(VectorAccess), + VectorRead(VectorRead), + VectorWrite(VectorWrite), } impl> Statement, T> { @@ -813,52 +237,6 @@ impl> Statement, T> { if_false, }) } - Statement::LoadVar(LoadVarDetails { - arg, - typ, - member_index, - }) => { - let dst = visitor.visit_ident( - arg.dst, - Some((&typ, ast::StateSpace::Reg)), - true, - false, - )?; - let src = visitor.visit_ident( - arg.src, - Some((&typ, ast::StateSpace::Local)), - false, - false, - )?; - Statement::LoadVar(LoadVarDetails { - arg: ast::LdArgs { dst, src }, - typ, - member_index, - }) - } - Statement::StoreVar(StoreVarDetails { - arg, - typ, - member_index, - }) => { - let src1 = visitor.visit_ident( - arg.src1, - Some((&typ, ast::StateSpace::Local)), - false, - false, - )?; - let src2 = visitor.visit_ident( - arg.src2, - Some((&typ, ast::StateSpace::Reg)), - false, - false, - )?; - Statement::StoreVar(StoreVarDetails { - arg: ast::StArgs { src1, src2 }, - typ, - member_index, - }) - } Statement::Conversion(ImplicitConversion { src, dst, @@ -900,9 +278,20 @@ impl> Statement, T> { Statement::Constant(ConstantDefinition { dst, typ, value }) } Statement::RetValue(data, value) => { - // TODO: - // We should report type here - let value = visitor.visit_ident(value, None, false, false)?; + let value = value + .into_iter() + .map(|(ident, type_)| { + Ok(( + visitor.visit_ident( + ident, + Some((&type_, ast::StateSpace::Local)), + false, + false, + )?, + type_, + )) + }) + .collect::, _>>()?; Statement::RetValue(data, value) } Statement::PtrAccess(PtrAccess { @@ -937,33 +326,69 @@ impl> Statement, T> { offset_src, }) } - Statement::VectorAccess(VectorAccess { + Statement::VectorRead(VectorRead { scalar_type, vector_width, - dst, - src: vector_src, + scalar_dst: dst, + vector_src, member, }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); let dst: SpirvWord = visitor.visit_ident( dst, - Some((&scalar_type.into(), ast::StateSpace::Reg)), + Some((&scalar_t, ast::StateSpace::Reg)), true, false, )?; let src = visitor.visit_ident( vector_src, - Some(( - &ast::Type::Vector(vector_width, scalar_type), - ast::StateSpace::Reg, - )), + Some((&vector_t, ast::StateSpace::Reg)), false, false, )?; - Statement::VectorAccess(VectorAccess { + Statement::VectorRead(VectorRead { + scalar_type, + vector_width, + scalar_dst: dst, + vector_src: src, + member, + }) + } + Statement::VectorWrite(VectorWrite { + scalar_type, + vector_width, + vector_dst, + vector_src, + scalar_src, + member, + }) => { + let scalar_t = scalar_type.into(); + let vector_t = ast::Type::Vector(vector_width, scalar_type); + let vector_dst = visitor.visit_ident( + vector_dst, + Some((&vector_t, ast::StateSpace::Reg)), + true, + false, + )?; + let vector_src = visitor.visit_ident( + vector_src, + Some((&vector_t, ast::StateSpace::Reg)), + false, + false, + )?; + let scalar_src = visitor.visit_ident( + scalar_src, + Some((&scalar_t, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::VectorWrite(VectorWrite { + vector_dst, + vector_src, + scalar_src, scalar_type, vector_width, - dst, - src, member, }) } @@ -1049,22 +474,6 @@ struct BrachCondition { if_true: SpirvWord, if_false: SpirvWord, } -struct LoadVarDetails { - arg: ast::LdArgs, - typ: ast::Type, - // (index, vector_width) - // HACK ALERT - // For some reason IGC explodes when you try to load from builtin vectors - // using OpInBoundsAccessChain, the one true way to do it is to - // OpLoad+OpCompositeExtract - member_index: Option<(u8, Option)>, -} - -struct StoreVarDetails { - arg: ast::StArgs, - typ: ast::Type, - member_index: Option, -} #[derive(Clone)] struct ImplicitConversion { @@ -1115,14 +524,14 @@ struct FunctionPointerDetails { } #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -struct SpirvWord(spirv::Word); +pub struct SpirvWord(u32); -impl From for SpirvWord { - fn from(value: spirv::Word) -> Self { +impl From for SpirvWord { + fn from(value: u32) -> Self { Self(value) } } -impl From for spirv::Word { +impl From for u32 { fn from(value: SpirvWord) -> Self { value.0 } @@ -1136,31 +545,6 @@ impl ast::Operand for SpirvWord { } } -fn pred_map_variable Result>( - this: ast::PredAt, - f: &mut F, -) -> Result, TranslateError> { - let new_label = f(this.label)?; - Ok(ast::PredAt { - not: this.not, - label: new_label, - }) -} - -pub(crate) enum Directive<'input> { - Variable(ast::LinkingDirective, ast::Variable), - Method(Function<'input>), -} - -pub(crate) struct Function<'input> { - pub func_decl: Rc>>, - pub globals: Vec>, - pub body: Option>, - import_as: Option, - tuning: Vec, - linkage: ast::LinkingDirective, -} - type ExpandedStatement = Statement, SpirvWord>; type NormalizedStatement = Statement< @@ -1171,577 +555,12 @@ type NormalizedStatement = Statement< ast::ParsedOperand, >; -type UnconditionalStatement = - Statement>, ast::ParsedOperand>; - -type TypedStatement = Statement, TypedOperand>; - -#[derive(Copy, Clone)] -enum TypedOperand { - Reg(SpirvWord), - RegOffset(SpirvWord, i32), - Imm(ast::ImmediateValue), - VecMember(SpirvWord, u8), -} - -impl TypedOperand { - fn map( - self, - fn_: impl FnOnce(SpirvWord, Option) -> Result, - ) -> Result { - Ok(match self { - TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?), - TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off), - TypedOperand::Imm(imm) => TypedOperand::Imm(imm), - TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx), - }) - } - - fn underlying_register(&self) -> Option { - match self { - Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r), - Self::Imm(_) => None, - } - } - - fn unwrap_reg(&self) -> Result { - match self { - TypedOperand::Reg(reg) => Ok(*reg), - _ => Err(error_unreachable()), - } - } -} - -impl ast::Operand for TypedOperand { - type Ident = SpirvWord; - - fn from_ident(ident: Self::Ident) -> Self { - TypedOperand::Reg(ident) - } -} - -impl ast::VisitorMap - for FnVisitor -where - Fn: FnMut( - TypedOperand, - Option<(&ast::Type, ast::StateSpace)>, - bool, - bool, - ) -> Result, -{ - fn visit( - &mut self, - args: TypedOperand, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - (self.fn_)(args, type_space, is_dst, relaxed_type_check) - } - - fn visit_ident( - &mut self, - args: SpirvWord, - type_space: Option<(&ast::Type, ast::StateSpace)>, - is_dst: bool, - relaxed_type_check: bool, - ) -> Result { - match (self.fn_)( - TypedOperand::Reg(args), - type_space, - is_dst, - relaxed_type_check, - )? { - TypedOperand::Reg(reg) => Ok(reg), - _ => Err(TranslateError::Unreachable), - } - } -} - -struct FnVisitor< - T, - U, - Err, - Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, -> { - fn_: Fn, - _marker: PhantomData Result>, -} - -impl< - T, - U, - Err, - Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, - > FnVisitor -{ - fn new(fn_: Fn) -> Self { - Self { - fn_, - _marker: PhantomData, - } - } -} - -fn register_external_fn_call<'a>( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - name: String, - return_arguments: impl Iterator, - input_arguments: impl Iterator, -) -> Result { - match ptx_impl_imports.entry(name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); - let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); - let func_decl = ast::MethodDeclaration:: { - return_arguments, - name: ast::MethodName::Func(fn_id), - input_arguments, - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - }; - entry.insert(Directive::Method(func)); - Ok(fn_id) - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => Ok(fn_id), - ast::MethodName::Kernel(_) => Err(error_unreachable()), - }, - _ => Err(error_unreachable()), - }, - } -} - -fn fn_arguments_to_variables<'a>( - id_defs: &mut NumericIdResolver, - args: impl Iterator, -) -> Vec> { - args.map(|(typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() -} - -fn hoist_function_globals(directives: Vec) -> Vec { - let mut result = Vec::with_capacity(directives.len()); - for directive in directives { - match directive { - Directive::Method(method) => { - for variable in method.globals { - result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); - } - result.push(Directive::Method(Function { - globals: Vec::new(), - ..method - })) - } - _ => result.push(directive), - } - } - result -} - -struct MethodsCallMap<'input> { - map: HashMap, HashSet>, -} - -impl<'input> MethodsCallMap<'input> { - fn new(module: &[Directive<'input>]) -> Self { - let mut directly_called_by = HashMap::new(); - for directive in module { - match directive { - Directive::Method(Function { - func_decl, - body: Some(statements), - .. - }) => { - let call_key: ast::MethodName<_> = (**func_decl).borrow().name; - if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { - entry.insert(Vec::new()); - } - for statement in statements { - match statement { - Statement::Instruction(ast::Instruction::Call { data, arguments }) => { - multi_hash_map_append( - &mut directly_called_by, - call_key, - arguments.func, - ); - } - _ => {} - } - } - } - _ => {} - } - } - let mut result = HashMap::new(); - for (&method_key, children) in directly_called_by.iter() { - let mut visited = HashSet::new(); - for child in children { - Self::add_call_map_single(&directly_called_by, &mut visited, *child); - } - result.insert(method_key, visited); - } - MethodsCallMap { map: result } - } - - fn add_call_map_single( - directly_called_by: &HashMap, Vec>, - visited: &mut HashSet, - current: SpirvWord, - ) { - if !visited.insert(current) { - return; - } - if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { - for child in children { - Self::add_call_map_single(directly_called_by, visited, *child); - } - } - } - - fn get_kernel_children(&self, name: &'input str) -> impl Iterator { - self.map - .get(&ast::MethodName::Kernel(name)) - .into_iter() - .flatten() - } - - fn kernels(&self) -> impl Iterator)> { - self.map - .iter() - .filter_map(|(method, children)| match method { - ast::MethodName::Kernel(kernel) => Some((*kernel, children)), - ast::MethodName::Func(..) => None, - }) - } - - fn methods( - &self, - ) -> impl Iterator, &HashSet)> { - self.map - .iter() - .map(|(method, children)| (*method, children)) - } - - fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) { - self.map - .get(&method) - .into_iter() - .flatten() - .copied() - .for_each(f); - } -} - -fn multi_hash_map_append< - K: Eq + std::hash::Hash, - V, - Collection: std::iter::Extend + std::default::Default, ->( - m: &mut HashMap, - key: K, - value: V, -) { - match m.entry(key) { - hash_map::Entry::Occupied(mut entry) => { - entry.get_mut().extend(iter::once(value)); - } - hash_map::Entry::Vacant(entry) => { - entry.insert(Default::default()).extend(iter::once(value)); - } - } -} - -fn normalize_variable_decls(directives: &mut Vec) { - for directive in directives { - match directive { - Directive::Method(Function { - body: Some(func), .. - }) => { - func[1..].sort_by_key(|s| match s { - Statement::Variable(_) => 0, - _ => 1, - }); - } - _ => (), - } - } -} - -// HACK ALERT! -// This function is a "good enough" heuristic of whetever to mark f16/f32 operations -// in the kernel as flushing denorms to zero or preserving them -// PTX support per-instruction ftz information. Unfortunately SPIR-V has no -// such capability, so instead we guesstimate which use is more common in the kernel -// and emit suitable execution mode -fn compute_denorm_information<'input>( - module: &[Directive<'input>], -) -> HashMap, HashMap> { - let mut denorm_methods = HashMap::new(); - for directive in module { - match directive { - Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} - Directive::Method(Function { - func_decl, - body: Some(statements), - .. - }) => { - let mut flush_counter = DenormCountMap::new(); - let method_key = (**func_decl).borrow().name; - for statement in statements { - match statement { - Statement::Instruction(inst) => { - if let Some((flush, width)) = flush_to_zero(inst) { - denorm_count_map_update(&mut flush_counter, width, flush); - } - } - Statement::LoadVar(..) => {} - Statement::StoreVar(..) => {} - Statement::Conditional(_) => {} - Statement::Conversion(_) => {} - Statement::Constant(_) => {} - Statement::RetValue(_, _) => {} - Statement::Label(_) => {} - Statement::Variable(_) => {} - Statement::PtrAccess { .. } => {} - Statement::VectorAccess { .. } => {} - Statement::RepackVector(_) => {} - Statement::FunctionPointer(_) => {} - } - } - denorm_methods.insert(method_key, flush_counter); - } - } - } - denorm_methods - .into_iter() - .map(|(name, v)| { - let width_to_denorm = v - .into_iter() - .map(|(k, flush_over_preserve)| { - let mode = if flush_over_preserve > 0 { - spirv::FPDenormMode::FlushToZero - } else { - spirv::FPDenormMode::Preserve - }; - (k, (mode, flush_over_preserve)) - }) - .collect(); - (name, width_to_denorm) - }) - .collect() -} - -fn flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { - match this { - ast::Instruction::Ld { .. } => None, - ast::Instruction::St { .. } => None, - ast::Instruction::Mov { .. } => None, - ast::Instruction::Not { .. } => None, - ast::Instruction::Bra { .. } => None, - ast::Instruction::Shl { .. } => None, - ast::Instruction::Shr { .. } => None, - ast::Instruction::Ret { .. } => None, - ast::Instruction::Call { .. } => None, - ast::Instruction::Or { .. } => None, - ast::Instruction::And { .. } => None, - ast::Instruction::Cvta { .. } => None, - ast::Instruction::Selp { .. } => None, - ast::Instruction::Bar { .. } => None, - ast::Instruction::Atom { .. } => None, - ast::Instruction::AtomCas { .. } => None, - ast::Instruction::Sub { - data: ast::ArithDetails::Integer(_), - .. - } => None, - ast::Instruction::Add { - data: ast::ArithDetails::Integer(_), - .. - } => None, - ast::Instruction::Mul { - data: ast::MulDetails::Integer { .. }, - .. - } => None, - ast::Instruction::Mad { - data: ast::MadDetails::Integer { .. }, - .. - } => None, - ast::Instruction::Min { - data: ast::MinMaxDetails::Signed(_), - .. - } => None, - ast::Instruction::Min { - data: ast::MinMaxDetails::Unsigned(_), - .. - } => None, - ast::Instruction::Max { - data: ast::MinMaxDetails::Signed(_), - .. - } => None, - ast::Instruction::Max { - data: ast::MinMaxDetails::Unsigned(_), - .. - } => None, - ast::Instruction::Cvt { - data: - ast::CvtDetails { - mode: - ast::CvtMode::ZeroExtend - | ast::CvtMode::SignExtend - | ast::CvtMode::Truncate - | ast::CvtMode::Bitcast - | ast::CvtMode::SaturateUnsignedToSigned - | ast::CvtMode::SaturateSignedToUnsigned - | ast::CvtMode::FPFromSigned(_) - | ast::CvtMode::FPFromUnsigned(_), - .. - }, - .. - } => None, - ast::Instruction::Div { - data: ast::DivDetails::Unsigned(_), - .. - } => None, - ast::Instruction::Div { - data: ast::DivDetails::Signed(_), - .. - } => None, - ast::Instruction::Clz { .. } => None, - ast::Instruction::Brev { .. } => None, - ast::Instruction::Popc { .. } => None, - ast::Instruction::Xor { .. } => None, - ast::Instruction::Bfe { .. } => None, - ast::Instruction::Bfi { .. } => None, - ast::Instruction::Rem { .. } => None, - ast::Instruction::Prmt { .. } => None, - ast::Instruction::Activemask { .. } => None, - ast::Instruction::Membar { .. } => None, - ast::Instruction::Sub { - data: ast::ArithDetails::Float(float_control), - .. - } - | ast::Instruction::Add { - data: ast::ArithDetails::Float(float_control), - .. - } - | ast::Instruction::Mul { - data: ast::MulDetails::Float(float_control), - .. - } - | ast::Instruction::Mad { - data: ast::MadDetails::Float(float_control), - .. - } => float_control - .flush_to_zero - .map(|ftz| (ftz, float_control.type_.size_of())), - ast::Instruction::Fma { data, .. } => { - data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) - } - ast::Instruction::Setp { data, .. } => { - data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) - } - ast::Instruction::SetpBool { data, .. } => data - .base - .flush_to_zero - .map(|ftz| (ftz, data.base.type_.size_of())), - ast::Instruction::Abs { data, .. } - | ast::Instruction::Rsqrt { data, .. } - | ast::Instruction::Neg { data, .. } - | ast::Instruction::Ex2 { data, .. } => { - data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) - } - ast::Instruction::Min { - data: ast::MinMaxDetails::Float(float_control), - .. - } - | ast::Instruction::Max { - data: ast::MinMaxDetails::Float(float_control), - .. - } => float_control - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())), - ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => { - data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) - } - // Modifier .ftz can only be specified when either .dtype or .atype - // is .f32 and applies only to single precision (.f32) inputs and results. - ast::Instruction::Cvt { - data: - ast::CvtDetails { - mode: - ast::CvtMode::FPExtend { flush_to_zero } - | ast::CvtMode::FPTruncate { flush_to_zero, .. } - | ast::CvtMode::FPRound { flush_to_zero, .. } - | ast::CvtMode::SignedFromFP { flush_to_zero, .. } - | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. }, - .. - }, - .. - } => flush_to_zero.map(|ftz| (ftz, 4)), - ast::Instruction::Div { - data: - ast::DivDetails::Float(ast::DivFloatDetails { - type_, - flush_to_zero, - .. - }), - .. - } => flush_to_zero.map(|ftz| (ftz, type_.size_of())), - ast::Instruction::Sin { data, .. } - | ast::Instruction::Cos { data, .. } - | ast::Instruction::Lg2 { data, .. } => { - Some((data.flush_to_zero, mem::size_of::() as u8)) - } - ptx_parser::Instruction::PrmtSlow { .. } => None, - ptx_parser::Instruction::Trap {} => None, - } -} - -type DenormCountMap = HashMap; - -fn denorm_count_map_update(map: &mut DenormCountMap, key: T, value: bool) { - let num_value = if value { 1 } else { -1 }; - denorm_count_map_update_impl(map, key, num_value); -} - -fn denorm_count_map_update_impl( - map: &mut DenormCountMap, - key: T, - num_value: isize, -) { - match map.entry(key) { - hash_map::Entry::Occupied(mut counter) => { - *(counter.get_mut()) += num_value; - } - hash_map::Entry::Vacant(entry) => { - entry.insert(num_value); - } - } -} - -pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { +enum Directive2<'input, Instruction, Operand: ast::Operand> { Variable(ast::LinkingDirective, ast::Variable), Method(Function2<'input, Instruction, Operand>), } -pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> { +struct Function2<'input, Instruction, Operand: ast::Operand> { pub func_decl: ast::MethodDeclaration<'input, SpirvWord>, pub globals: Vec>, pub body: Option>>, @@ -1861,6 +680,41 @@ impl<'input, 'b> ScopedResolver<'input, 'b> { scope.flush(self.flat_resolver); } + fn add_or_get_in_current_scope_untyped( + &mut self, + name: &'input str, + ) -> Result { + let current_scope = self.scopes.last_mut().unwrap(); + Ok( + match current_scope.name_to_ident.entry(Cow::Borrowed(name)) { + hash_map::Entry::Occupied(occupied_entry) => { + let ident = *occupied_entry.get(); + let entry = current_scope + .ident_map + .get(&ident) + .ok_or_else(|| error_unreachable())?; + if entry.type_space.is_some() { + return Err(error_unknown_symbol()); + } + ident + } + hash_map::Entry::Vacant(vacant_entry) => { + let new_id = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + vacant_entry.insert(new_id); + current_scope.ident_map.insert( + new_id, + IdentEntry { + name: Some(Cow::Borrowed(name)), + type_space: None, + }, + ); + new_id + } + }, + ) + } + fn add( &mut self, name: Cow<'input, str>, @@ -1949,19 +803,6 @@ impl SpecialRegistersMap2 { self.id_to_reg.get(&id).copied() } - fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { - match self.reg_to_id.entry(reg) { - hash_map::Entry::Occupied(e) => *e.get(), - hash_map::Entry::Vacant(e) => { - let numeric_id = SpirvWord(current_id.0); - current_id.0 += 1; - e.insert(numeric_id); - self.id_to_reg.insert(numeric_id, reg); - numeric_id - } - } - } - fn generate_declarations<'a, 'input>( resolver: &'a mut GlobalStringIdentResolver2<'input>, ) -> impl ExactSizeIterator< @@ -1975,7 +816,7 @@ impl SpecialRegistersMap2 { let name = ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); let return_type = sreg.get_function_return_type(); - let input_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_input_type(); ( sreg, ast::MethodDeclaration { @@ -1988,14 +829,17 @@ impl SpecialRegistersMap2 { array_init: Vec::new(), }], name: name, - input_arguments: vec![ast::Variable { - align: None, - v_type: input_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }], + input_arguments: input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(), shared_mem: None, }, ) @@ -2003,10 +847,49 @@ impl SpecialRegistersMap2 { } } -pub struct VectorAccess { +pub struct VectorRead { scalar_type: ast::ScalarType, vector_width: u8, - dst: SpirvWord, - src: SpirvWord, + scalar_dst: SpirvWord, + vector_src: SpirvWord, member: u8, } + +pub struct VectorWrite { + scalar_type: ast::ScalarType, + vector_width: u8, + vector_dst: SpirvWord, + vector_src: SpirvWord, + scalar_src: SpirvWord, + member: u8, +} + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} + +type UnconditionalStatement = + Statement>, ast::ParsedOperand>; diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs deleted file mode 100644 index b598345..0000000 --- a/ptx/src/pass/normalize_identifiers.rs +++ /dev/null @@ -1,80 +0,0 @@ -use super::*; -use ptx_parser as ast; - -pub(crate) fn run<'input, 'b>( - id_defs: &mut FnStringIdResolver<'input, 'b>, - fn_defs: &GlobalFnDeclResolver<'input, 'b>, - func: Vec>>, -) -> Result, TranslateError> { - for s in func.iter() { - match s { - ast::Statement::Label(id) => { - id_defs.add_def(*id, None, false); - } - _ => (), - } - } - let mut result = Vec::new(); - for s in func { - expand_map_variables(id_defs, fn_defs, &mut result, s)?; - } - Ok(result) -} - -fn expand_map_variables<'a, 'b>( - id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver<'a, 'b>, - result: &mut Vec, - s: ast::Statement>, -) -> Result<(), TranslateError> { - match s { - ast::Statement::Block(block) => { - id_defs.start_block(); - for s in block { - expand_map_variables(id_defs, fn_defs, result, s)?; - } - id_defs.end_block(); - } - ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), - ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( - p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) - .transpose()?, - ast::visit_map(i, &mut |id, - _: Option<(&ast::Type, ast::StateSpace)>, - _: bool, - _: bool| { - id_defs.get_id(id) - })?, - ))), - ast::Statement::Variable(var) => { - let var_type = var.var.v_type.clone(); - match var.count { - Some(count) => { - for new_id in - id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) - { - result.push(Statement::Variable(ast::Variable { - align: var.var.align, - v_type: var.var.v_type.clone(), - state_space: var.var.state_space, - name: new_id, - array_init: var.var.array_init.clone(), - })) - } - } - None => { - let new_id = - id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); - result.push(Statement::Variable(ast::Variable { - align: var.var.align, - v_type: var.var.v_type.clone(), - state_space: var.var.state_space, - name: new_id, - array_init: var.var.array_init, - })); - } - } - } - }; - Ok(()) -} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index beaf08b..5155886 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -1,6 +1,5 @@ use super::*; use ptx_parser as ast; -use rustc_hash::FxHashMap; pub(crate) fn run<'input, 'b>( resolver: &mut ScopedResolver<'input, 'b>, @@ -37,7 +36,7 @@ fn run_method<'input, 'b>( let name = match method.func_directive.name { ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), ast::MethodName::Func(text) => { - ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?) + ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?) } }; resolver.start_scope(); diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs deleted file mode 100644 index 037e918..0000000 --- a/ptx/src/pass/normalize_labels.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::{collections::HashSet, iter}; - -use super::*; - -pub(super) fn run( - func: Vec, - id_def: &mut NumericIdResolver, -) -> Vec { - let mut labels_in_use = HashSet::new(); - for s in func.iter() { - match s { - Statement::Instruction(i) => { - if let Some(target) = jump_target(i) { - labels_in_use.insert(target); - } - } - Statement::Conditional(cond) => { - labels_in_use.insert(cond.if_true); - labels_in_use.insert(cond.if_false); - } - Statement::Variable(..) - | Statement::LoadVar(..) - | Statement::StoreVar(..) - | Statement::RetValue(..) - | Statement::Conversion(..) - | Statement::Constant(..) - | Statement::Label(..) - | Statement::PtrAccess { .. } - | Statement::VectorAccess { .. } - | Statement::RepackVector(..) - | Statement::FunctionPointer(..) => {} - } - } - iter::once(Statement::Label(id_def.register_intermediate(None))) - .chain(func.into_iter().filter(|s| match s { - Statement::Label(i) => labels_in_use.contains(i), - _ => true, - })) - .collect::>() -} - -fn jump_target>( - this: &ast::Instruction, -) -> Option { - match this { - ast::Instruction::Bra { arguments } => Some(arguments.src), - _ => None, - } -} diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs deleted file mode 100644 index c971cfa..0000000 --- a/ptx/src/pass/normalize_predicates.rs +++ /dev/null @@ -1,44 +0,0 @@ -use super::*; -use ptx_parser as ast; - -pub(crate) fn run( - func: Vec, - id_def: &mut NumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func { - match s { - Statement::Label(id) => result.push(Statement::Label(id)), - Statement::Instruction((pred, inst)) => { - if let Some(pred) = pred { - let if_true = id_def.register_intermediate(None); - let if_false = id_def.register_intermediate(None); - let folded_bra = match &inst { - ast::Instruction::Bra { arguments, .. } => Some(arguments.src), - _ => None, - }; - let mut branch = BrachCondition { - predicate: pred.label, - if_true: folded_bra.unwrap_or(if_true), - if_false, - }; - if pred.not { - std::mem::swap(&mut branch.if_true, &mut branch.if_false); - } - result.push(Statement::Conditional(branch)); - if folded_bra.is_none() { - result.push(Statement::Label(if_true)); - result.push(Statement::Instruction(inst)); - } - result.push(Statement::Label(if_false)); - } else { - result.push(Statement::Instruction(inst)); - } - } - Statement::Variable(var) => result.push(Statement::Variable(var)), - // Blocks are flattened when resolving ids - _ => return Err(error_unreachable()), - } - } - Ok(result) -} diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs new file mode 100644 index 0000000..70d77d3 --- /dev/null +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -0,0 +1,187 @@ +use super::*; + +pub(super) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut fn_declarations = FxHashMap::default(); + let remapped_directives = directives + .into_iter() + .map(|directive| run_directive(resolver, &mut fn_declarations, directive)) + .collect::, _>>()?; + let mut result = fn_declarations + .into_iter() + .map(|(_, (return_arguments, name, input_arguments))| { + Directive2::Method(Function2 { + func_decl: ast::MethodDeclaration { + return_arguments, + name: ast::MethodName::Func(name), + input_arguments, + shared_mem: None, + }, + globals: Vec::new(), + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }) + }) + .collect::>(); + result.extend(remapped_directives); + Ok(result) +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, fn_declarations, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + statements + .into_iter() + .map(|statement| { + Ok(match statement { + Statement::Instruction(instruction) => { + Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?) + } + s => s, + }) + }) + .collect::, _>>() +} + +fn run_instruction<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + instruction: ptx_parser::Instruction, +) -> Result, TranslateError> { + Ok(match instruction { + i @ ptx_parser::Instruction::Activemask { .. } => { + to_call(resolver, fn_declarations, "activemask".into(), i)? + } + i @ ptx_parser::Instruction::Bfe { data, .. } => { + let name = ["bfe_", scalar_to_ptx_name(data)].concat(); + to_call(resolver, fn_declarations, name.into(), i)? + } + i @ ptx_parser::Instruction::Bfi { data, .. } => { + let name = ["bfi_", scalar_to_ptx_name(data)].concat(); + to_call(resolver, fn_declarations, name.into(), i)? + } + i => i, + }) +} + +fn to_call<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + fn_declarations: &mut FxHashMap< + Cow<'input, str>, + ( + Vec>, + SpirvWord, + Vec>, + ), + >, + name: Cow<'input, str>, + i: ast::Instruction, +) -> Result, TranslateError> { + let mut data_return = Vec::new(); + let mut data_input = Vec::new(); + let mut arguments_return = Vec::new(); + let mut arguments_input = Vec::new(); + ast::visit(&i, &mut |name: &SpirvWord, + type_space: Option<( + &ptx_parser::Type, + ptx_parser::StateSpace, + )>, + is_dst: bool, + _: bool| { + let (type_, space) = type_space.ok_or_else(error_mismatched_type)?; + if is_dst { + data_return.push((type_.clone(), space)); + arguments_return.push(*name); + } else { + data_input.push((type_.clone(), space)); + arguments_input.push(*name); + }; + Ok::<_, TranslateError>(()) + })?; + let fn_name = match fn_declarations.entry(name) { + hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1, + hash_map::Entry::Vacant(vacant_entry) => { + let name = vacant_entry.key().clone(); + let full_name = [ZLUDA_PTX_PREFIX, &*name].concat(); + let name = resolver.register_named(Cow::Owned(full_name.clone()), None); + vacant_entry.insert(( + to_variables(resolver, &data_return), + name, + to_variables(resolver, &data_input), + )); + name + } + }; + Ok(ast::Instruction::Call { + data: ptx_parser::CallDetails { + uniform: false, + return_arguments: data_return, + input_arguments: data_input, + }, + arguments: ptx_parser::CallArgs { + return_arguments: arguments_return, + func: fn_name, + input_arguments: arguments_input, + }, + }) +} + +fn to_variables<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>, +) -> Vec> { + arguments + .iter() + .map(|(type_, space)| ast::Variable { + align: None, + v_type: type_.clone(), + state_space: *space, + name: resolver.register_unnamed(Some((type_.clone(), *space))), + array_init: Vec::new(), + }) + .collect::>() +} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop deleted file mode 100644 index e3a4022..0000000 --- a/ptx/src/ptx.lalrpop +++ /dev/null @@ -1,2198 +0,0 @@ -use crate::ast; -use crate::ast::UnwrapWithVec; -use crate::{without_none, vector_index}; - -use lalrpop_util::ParseError; -use std::convert::TryInto; - -grammar<'err>(errors: &'err mut Vec, ast::PtxError>>); - -extern { - type Error = ast::PtxError; -} - -match { - r"\s+" => { }, - r"//[^\n\r]*[\n\r]*" => { }, - r"/\*[^*]*\*+(?:[^/*][^*]*\*+)*/" => { }, - r"0[fF][0-9a-zA-Z]{8}" => F32NumToken, - r"0[dD][0-9a-zA-Z]{16}" => F64NumToken, - r"0[xX][0-9a-zA-Z]+U?" => HexNumToken, - r"[0-9]+U?" => DecimalNumToken, - r#""[^"]*""# => String, - r"[0-9]+\.[0-9]+" => VersionNumber, - "!", - "(", ")", - "+", - "-", - ",", - ".", - ":", - ";", - "@", - "[", "]", - "{", "}", - "<", ">", - "|", - "=", - ".acq_rel", - ".acquire", - ".add", - ".address_size", - ".align", - ".aligned", - ".and", - ".approx", - ".b16", - ".b32", - ".b64", - ".b8", - ".ca", - ".cas", - ".cg", - ".const", - ".cs", - ".cta", - ".cv", - ".dec", - ".entry", - ".eq", - ".equ", - ".exch", - ".extern", - ".f16", - ".f16x2", - ".f32", - ".f64", - ".file", - ".ftz", - ".full", - ".func", - ".ge", - ".geu", - ".gl", - ".global", - ".gpu", - ".gt", - ".gtu", - ".hi", - ".hs", - ".inc", - ".le", - ".leu", - ".lo", - ".loc", - ".local", - ".ls", - ".lt", - ".ltu", - ".lu", - ".max", - ".maxnreg", - ".maxntid", - ".minnctapersm", - ".min", - ".nan", - ".NaN", - ".nc", - ".ne", - ".neu", - ".num", - ".or", - ".param", - ".pragma", - ".pred", - ".reg", - ".relaxed", - ".release", - ".reqntid", - ".rm", - ".rmi", - ".rn", - ".rni", - ".rp", - ".rpi", - ".rz", - ".rzi", - ".s16", - ".s32", - ".s64", - ".s8" , - ".sat", - ".section", - ".shared", - ".sync", - ".sys", - ".target", - ".to", - ".u16", - ".u32", - ".u64", - ".u8" , - ".uni", - ".v2", - ".v4", - ".version", - ".visible", - ".volatile", - ".wb", - ".weak", - ".wide", - ".wt", - ".xor", -} else { - // IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID - "abs", - "activemask", - "add", - "and", - "atom", - "bar", - "barrier", - "bfe", - "bfi", - "bra", - "brev", - "call", - "clz", - "cos", - "cvt", - "cvta", - "debug", - "div", - "ex2", - "fma", - "ld", - "lg2", - "mad", - "map_f64_to_f32", - "max", - "membar", - "min", - "mov", - "mul", - "neg", - "not", - "or", - "popc", - "prmt", - "rcp", - "rem", - "ret", - "rsqrt", - "selp", - "setp", - "shl", - "shr", - "sin", - r"sm_[0-9]+" => ShaderModel, - "sqrt", - "st", - "sub", - "texmode_independent", - "texmode_unified", - "xor", -} else { - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers - r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID, - r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID, -} - -ExtendedID : &'input str = { - "abs", - "activemask", - "add", - "and", - "atom", - "bar", - "barrier", - "bfe", - "bfi", - "bra", - "brev", - "call", - "clz", - "cos", - "cvt", - "cvta", - "debug", - "div", - "ex2", - "fma", - "ld", - "lg2", - "mad", - "map_f64_to_f32", - "max", - "membar", - "min", - "mov", - "mul", - "neg", - "not", - "or", - "popc", - "prmt", - "rcp", - "rem", - "ret", - "rsqrt", - "selp", - "setp", - "shl", - "shr", - "sin", - ShaderModel, - "sqrt", - "st", - "sub", - "texmode_independent", - "texmode_unified", - "xor", - ID -} - -NumToken: (&'input str, u32, bool) = { - => { - if s.ends_with('U') { - (&s[2..s.len() - 1], 16, true) - } else { - (&s[2..], 16, false) - } - }, - => { - let radix = if s.starts_with('0') { 8 } else { 10 }; - if s.ends_with('U') { - (&s[..s.len() - 1], radix, true) - } else { - (s, radix, false) - } - } -} - -F32Num: f32 = { - => { - match u32::from_str_radix(&s[2..], 16) { - Ok(x) => unsafe { std::mem::transmute::<_, f32>(x) }, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0.0 - } - } - - } -} - -F64Num: f64 = { - => { - match u64::from_str_radix(&s[2..], 16) { - Ok(x) => unsafe { std::mem::transmute::<_, f64>(x) }, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0.0 - } - } - } -} - -U8Num: u8 = { - => { - let (text, radix, _) = x; - match u8::from_str_radix(text, radix) { - Ok(x) => x, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - } - } - } -} - -U16Num: u16 = { - => { - let (text, radix, _) = x; - match u16::from_str_radix(text, radix) { - Ok(x) => x, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - } - } - } -} - -U32Num: u32 = { - => { - let (text, radix, _) = x; - match u32::from_str_radix(text, radix) { - Ok(x) => x, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - } - } - } -} - -// TODO: handle negative number properly -S32Num: i32 = { - => { - let (text, radix, _) = x; - match i32::from_str_radix(text, radix) { - Ok(x) => if sign.is_some() { -x } else { x }, - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - } - } - } -} - -pub Module: ast::Module<'input> = { - Target => { - ast::Module { version: v, directives: without_none(d) } - } -}; - -Version: (u8, u8) = { - ".version" => { - let dot = v.find('.').unwrap(); - let major = v[..dot].parse::().unwrap_or_else(|err| { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - }); - let minor = v[dot+1..].parse::().unwrap_or_else(|err| { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - 0 - }); - (major,minor) - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-module-directives-target -Target = { - ".target" Comma -}; - -TargetSpecifier = { - ShaderModel, - "texmode_unified", - "texmode_independent", - "debug", - "map_f64_to_f32" -}; - -Directive: Option>> = { - AddressSize => None, - => { - let (linking, func) = f; - Some(ast::Directive::Method(linking, func)) - }, - File => None, - Section => None, - ";" => { - let (linking, var) = v; - Some(ast::Directive::Variable(linking, var)) - }, - @L ! @R => { - let (start, _, end)= (<>); - errors.push(ParseError::User { error: - ast::PtxError::UnrecognizedDirective { start, end } - }); - None - } -}; - -AddressSize = { - ".address_size" U8Num -}; - -Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>) = { - - - - => { - (linking, ast::Function{func_directive, tuning, body}) - } -}; - -LinkingDirective: ast::LinkingDirective = { - ".extern" => ast::LinkingDirective::EXTERN, - ".visible" => ast::LinkingDirective::VISIBLE, - ".weak" => ast::LinkingDirective::WEAK, -}; - -TuningDirective: ast::TuningDirective = { - ".maxnreg" => ast::TuningDirective::MaxNReg(ncta), - ".maxntid" => ast::TuningDirective::MaxNtid(nx, 1, 1), - ".maxntid" "," => ast::TuningDirective::MaxNtid(nx, ny, 1), - ".maxntid" "," "," => ast::TuningDirective::MaxNtid(nx, ny, nz), - ".reqntid" => ast::TuningDirective::ReqNtid(nx, 1, 1), - ".reqntid" "," => ast::TuningDirective::ReqNtid(nx, ny, 1), - ".reqntid" "," "," => ast::TuningDirective::ReqNtid(nx, ny, nz), - ".minnctapersm" => ast::TuningDirective::MinNCtaPerSm(ncta), -}; - -LinkingDirectives: ast::LinkingDirective = { - => { - ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y) - } -} - -MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { - ".entry" => { - let return_arguments = Vec::new(); - let name = ast::MethodName::Kernel(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } - }, - ".func" => { - let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); - let name = ast::MethodName::Func(name); - ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } - } -}; - -KernelArguments: Vec> = { - "(" > ")" => args -}; - -FnArguments: Vec> = { - "(" > ")" => args -}; - -KernelInput: ast::Variable<&'input str> = { - => { - let (align, v_type, name) = v; - ast::Variable { - align, - v_type, - state_space: ast::StateSpace::Param, - name, - array_init: Vec::new() - } - } -} - -FnInput: ast::Variable<&'input str> = { - => { - let (align, v_type, name) = v; - let state_space = ast::StateSpace::Reg; - ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } - }, - => { - let (align, v_type, name) = v; - let state_space = ast::StateSpace::Param; - ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } - } -} - -FunctionBody: Option>>> = { - "{" "}" => { Some(without_none(s)) }, - ";" => { None } -}; - -StateSpaceSpecifier: ast::StateSpace = { - ".reg" => ast::StateSpace::Reg, - ".const" => ast::StateSpace::Const, - ".global" => ast::StateSpace::Global, - ".local" => ast::StateSpace::Local, - ".shared" => ast::StateSpace::Shared, - ".param" => ast::StateSpace::Param, // used to prepare function call -}; - -#[inline] -ScalarType: ast::ScalarType = { - ".f16" => ast::ScalarType::F16, - ".f16x2" => ast::ScalarType::F16x2, - ".pred" => ast::ScalarType::Pred, - ".b8" => ast::ScalarType::B8, - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u8" => ast::ScalarType::U8, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s8" => ast::ScalarType::S8, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - -Statement: Option>> = { - => Some(ast::Statement::Label(l)), - DebugDirective => None, - ";" => Some(ast::Statement::Variable(v)), - ";" => Some(ast::Statement::Instruction(p, i)), - PragmaStatement => None, - "{" "}" => Some(ast::Statement::Block(without_none(s))), - @L ! ";" @R => { - let (start, _, _, end) = (<>); - errors.push(ParseError::User { error: - ast::PtxError::UnrecognizedStatement { start, end } - }); - None - } -}; - -PragmaStatement: () = { - ".pragma" String ";" -} - -DebugDirective: () = { - DebugLocation -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-loc -DebugLocation = { - ".loc" U32Num U32Num U32Num -}; - -Label: &'input str = { - ":" => id -}; - -Align: u32 = { - ".align" => x -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names -MultiVariable: ast::MultiVariable<&'input str> = { - => ast::MultiVariable{<>} -} - -VariableParam: u32 = { - "<" ">" => n -} - -Variable: ast::Variable<&'input str> = { - => { - let (align, v_type, name) = v; - let state_space = ast::StateSpace::Reg; - ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} - }, - LocalVariable, - => { - let (align, array_init, v_type, name) = v; - let state_space = ast::StateSpace::Param; - ast::Variable {align, v_type, state_space, name, array_init} - }, - SharedVariable, -}; - -RegVariable: (Option, ast::Type, &'input str) = { - ".reg" > => { - let (align, t, name) = var; - let v_type = ast::Type::Scalar(t); - (align, v_type, name) - }, - ".reg" > => { - let (align, v_len, t, name) = var; - let v_type = ast::Type::Vector(t, v_len); - (align, v_type, name) - } -} - -LocalVariable: ast::Variable<&'input str> = { - ".local" > => { - let (align, t, name) = var; - let v_type = ast::Type::Scalar(t); - let state_space = ast::StateSpace::Local; - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } - }, - ".local" > => { - let (align, v_len, t, name) = var; - let v_type = ast::Type::Vector(t, v_len); - let state_space = ast::StateSpace::Local; - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } - }, - ".local" > => { - let (align, t, name, arr_or_ptr) = var; - let state_space = ast::StateSpace::Local; - let (v_type, array_init) = match arr_or_ptr { - ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::Type::Array(t, dimensions), init) - } - ast::ArrayOrPointer::Pointer => { - errors.push(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); - (ast::Type::Array(t, Vec::new()), Vec::new()) - } - }; - ast::Variable { align, v_type, state_space, name, array_init } - } -} - -SharedVariable: ast::Variable<&'input str> = { - ".shared" > => { - let (align, t, name) = var; - let state_space = ast::StateSpace::Shared; - let v_type = ast::Type::Scalar(t); - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } - }, - ".shared" > => { - let (align, v_len, t, name) = var; - let state_space = ast::StateSpace::Shared; - let v_type = ast::Type::Vector(t, v_len); - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } - }, - ".shared" > => { - let (align, t, name, arr_or_ptr) = var; - let state_space = ast::StateSpace::Shared; - let (v_type, array_init) = match arr_or_ptr { - ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::Type::Array(t, dimensions), init) - } - ast::ArrayOrPointer::Pointer => { - errors.push(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); - (ast::Type::Array(t, Vec::new()), Vec::new()) - } - }; - ast::Variable { align, v_type, state_space, name, array_init } - } -} - -ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { - => { - let (align, v_type, name, array_init) = def; - (linking, ast::Variable { align, v_type, state_space, name, array_init }) - }, - > => { - let (align, t, name, arr_or_ptr) = var; - let (v_type, state_space, array_init) = match arr_or_ptr { - ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::Type::Array(t, dimensions), space, init) - } - ast::ArrayOrPointer::Pointer => { - if !linking.contains(ast::LinkingDirective::EXTERN) { - errors.push(ParseError::User { error: ast::PtxError::NonExternPointer }); - } - (ast::Type::Array(t, Vec::new()), space, Vec::new()) - } - }; - (linking, ast::Variable{ align, v_type, state_space, name, array_init }) - } -} - -VariableStateSpace: ast::StateSpace = { - ".const" => ast::StateSpace::Const, - ".global" => ast::StateSpace::Global, - ".shared" => ast::StateSpace::Shared, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option, Vec, ast::Type, &'input str) = { - ".param" > => { - let (align, t, name) = var; - let v_type = ast::Type::Scalar(t); - (align, Vec::new(), v_type, name) - }, - ".param" > => { - let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { - ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::Type::Array(t, dimensions), init) - } - ast::ArrayOrPointer::Pointer => { - (ast::Type::Scalar(t), Vec::new()) - } - }; - (align, array_init, v_type, name) - } -} - -ParamDeclaration: (Option, ast::Type, &'input str) = { - => { - let (align, array_init, v_type, name) = var; - if array_init.len() > 0 { - errors.push(ParseError::User { error: ast::PtxError::ArrayInitalizer }); - } - (align, v_type, name) - } -} - -GlobalVariableDefinitionNoArray: (Option, ast::Type, &'input str, Vec) = { - > => { - let (align, t, name) = scalar; - let v_type = ast::Type::Scalar(t); - (align, v_type, name, Vec::new()) - }, - > => { - let (align, v_len, t, name) = var; - let v_type = ast::Type::Vector(t, v_len); - (align, v_type, name, Vec::new()) - }, -} - -#[inline] -SizedScalarType: ast::ScalarType = { - ".b8" => ast::ScalarType::B8, - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u8" => ast::ScalarType::U8, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s8" => ast::ScalarType::S8, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f16" => ast::ScalarType::F16, - ".f16x2" => ast::ScalarType::F16x2, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -} - -#[inline] -LdStScalarType: ast::ScalarType = { - ".b8" => ast::ScalarType::B8, - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u8" => ast::ScalarType::U8, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s8" => ast::ScalarType::S8, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f16" => ast::ScalarType::F16, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -} - -Instruction: ast::Instruction> = { - InstLd, - InstMov, - InstMul, - InstAdd, - InstSetp, - InstNot, - InstBra, - InstCvt, - InstShl, - InstShr, - InstSt, - InstRet, - InstCvta, - InstCall, - InstAbs, - InstMad, - InstFma, - InstOr, - InstAnd, - InstSub, - InstMin, - InstMax, - InstRcp, - InstSelp, - InstBar, - InstAtom, - InstAtomCas, - InstDiv, - InstSqrt, - InstRsqrt, - InstNeg, - InstSin, - InstCos, - InstLg2, - InstEx2, - InstClz, - InstBrev, - InstPopc, - InstXor, - InstRem, - InstBfe, - InstBfi, - InstPrmt, - InstActivemask, - InstMembar, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld -InstLd: ast::Instruction> = { - "ld" "," => { - ast::Instruction::Ld( - ast::LdDetails { - qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StateSpace::Generic), - caching: cop.unwrap_or(ast::LdCacheOperator::Cached), - typ: t, - non_coherent: false - }, - ast::Arg2Ld { dst:dst, src:src } - ) - }, - "ld" ".global" "," => { - ast::Instruction::Ld( - ast::LdDetails { - qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ast::StateSpace::Global, - caching: cop.unwrap_or(ast::LdCacheOperator::Cached), - typ: t, - non_coherent: false - }, - ast::Arg2Ld { dst:dst, src:src } - ) - }, - "ld" ".global" ".nc" "," => { - ast::Instruction::Ld( - ast::LdDetails { - qualifier: ast::LdStQualifier::Weak, - state_space: ast::StateSpace::Global, - caching: cop.unwrap_or(ast::LdCacheOperator::Cached), - typ: t, - non_coherent: true - }, - ast::Arg2Ld { dst:dst, src:src } - ) - } -}; - -LdStType: ast::Type = { - => ast::Type::Vector(t, v), - => ast::Type::Scalar(t), -} - -LdStQualifier: ast::LdStQualifier = { - ".weak" => ast::LdStQualifier::Weak, - ".volatile" => ast::LdStQualifier::Volatile, - ".relaxed" => ast::LdStQualifier::Relaxed(s), - ".acquire" => ast::LdStQualifier::Acquire(s), -}; - -MemScope: ast::MemScope = { - ".cta" => ast::MemScope::Cta, - ".gpu" => ast::MemScope::Gpu, - ".sys" => ast::MemScope::Sys -}; - -MembarLevel: ast::MemScope = { - ".cta" => ast::MemScope::Cta, - ".gl" => ast::MemScope::Gpu, - ".sys" => ast::MemScope::Sys -}; - -LdNonGlobalStateSpace: ast::StateSpace = { - ".const" => ast::StateSpace::Const, - ".local" => ast::StateSpace::Local, - ".param" => ast::StateSpace::Param, - ".shared" => ast::StateSpace::Shared, -}; - -LdCacheOperator: ast::LdCacheOperator = { - ".ca" => ast::LdCacheOperator::Cached, - ".cg" => ast::LdCacheOperator::L2Only, - ".cs" => ast::LdCacheOperator::Streaming, - ".lu" => ast::LdCacheOperator::LastUse, - ".cv" => ast::LdCacheOperator::Uncached, -}; - -LdNcCacheOperator: ast::LdCacheOperator = { - ".ca" => ast::LdCacheOperator::Cached, - ".cg" => ast::LdCacheOperator::L2Only, - ".cs" => ast::LdCacheOperator::Streaming, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov -InstMov: ast::Instruction> = { - "mov" "," => { - let mov_type = match pref { - Some(vec_width) => ast::Type::Vector(t, vec_width), - None => ast::Type::Scalar(t) - }; - let details = ast::MovDetails::new(mov_type); - ast::Instruction::Mov( - details, - ast::Arg2Mov { dst, src } - ) - } -} - -#[inline] -MovScalarType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, - ".pred" => ast::ScalarType::Pred -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul -InstMul: ast::Instruction> = { - "mul" => ast::Instruction::Mul(d, a) -}; - -MulDetails: ast::MulDetails = { - => ast::MulDetails::Unsigned(ast::MulUInt{ - typ: t, - control: ctr - }), - => ast::MulDetails::Signed(ast::MulSInt{ - typ: t, - control: ctr - }), - => ast::MulDetails::Float(f) -}; - -MulIntControl: ast::MulIntControl = { - ".hi" => ast::MulIntControl::High, - ".lo" => ast::MulIntControl::Low, - ".wide" => ast::MulIntControl::Wide -}; - -#[inline] -RoundingModeFloat : ast::RoundingMode = { - ".rn" => ast::RoundingMode::NearestEven, - ".rz" => ast::RoundingMode::Zero, - ".rm" => ast::RoundingMode::NegativeInf, - ".rp" => ast::RoundingMode::PositiveInf, -}; - -RoundingModeInt : ast::RoundingMode = { - ".rni" => ast::RoundingMode::NearestEven, - ".rzi" => ast::RoundingMode::Zero, - ".rmi" => ast::RoundingMode::NegativeInf, - ".rpi" => ast::RoundingMode::PositiveInf, -}; - -IntType : ast::ScalarType = { - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -}; - -IntType3264: ast::ScalarType = { - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -} - -UIntType: ast::ScalarType = { - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, -}; - -SIntType: ast::ScalarType = { - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -}; - -FloatType: ast::ScalarType = { - ".f16" => ast::ScalarType::F16, - ".f16x2" => ast::ScalarType::F16x2, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add -InstAdd: ast::Instruction> = { - "add" => ast::Instruction::Add(d, a) -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp -// TODO: support f16 setp -InstSetp: ast::Instruction> = { - "setp" => ast::Instruction::Setp(d, a), - "setp" => ast::Instruction::SetpBool(d, a), -}; - -SetpMode: ast::SetpData = { - => ast::SetpData { - typ: t, - flush_to_zero: None, - cmp_op: cmp_op, - }, - ".f32" => ast::SetpData { - typ: ast::ScalarType::F32, - flush_to_zero: Some(ftz.is_some()), - cmp_op: cmp_op, - } - -}; - -SetpBoolMode: ast::SetpBoolData = { - => ast::SetpBoolData { - typ: t, - flush_to_zero: None, - cmp_op: cmp_op, - bool_op: bool_op, - }, - ".f32" => ast::SetpBoolData { - typ: ast::ScalarType::F32, - flush_to_zero: Some(ftz.is_some()), - cmp_op: cmp_op, - bool_op: bool_op, - } -}; - -SetpCompareOp: ast::SetpCompareOp = { - ".eq" => ast::SetpCompareOp::Eq, - ".ne" => ast::SetpCompareOp::NotEq, - ".lt" => ast::SetpCompareOp::Less, - ".le" => ast::SetpCompareOp::LessOrEq, - ".gt" => ast::SetpCompareOp::Greater, - ".ge" => ast::SetpCompareOp::GreaterOrEq, - ".lo" => ast::SetpCompareOp::Less, - ".ls" => ast::SetpCompareOp::LessOrEq, - ".hi" => ast::SetpCompareOp::Greater, - ".hs" => ast::SetpCompareOp::GreaterOrEq, - ".equ" => ast::SetpCompareOp::NanEq, - ".neu" => ast::SetpCompareOp::NanNotEq, - ".ltu" => ast::SetpCompareOp::NanLess, - ".leu" => ast::SetpCompareOp::NanLessOrEq, - ".gtu" => ast::SetpCompareOp::NanGreater, - ".geu" => ast::SetpCompareOp::NanGreaterOrEq, - ".num" => ast::SetpCompareOp::IsNotNan, - ".nan" => ast::SetpCompareOp::IsAnyNan, -}; - -SetpBoolPostOp: ast::SetpBoolPostOp = { - ".and" => ast::SetpBoolPostOp::And, - ".or" => ast::SetpBoolPostOp::Or, - ".xor" => ast::SetpBoolPostOp::Xor, -}; - -SetpTypeNoF32: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f64" => ast::ScalarType::F64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not -InstNot: ast::Instruction> = { - "not" => ast::Instruction::Not(t, a) -}; - -BooleanType: ast::ScalarType = { - ".pred" => ast::ScalarType::Pred, - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at -PredAt: ast::PredAt<&'input str> = { - "@" => ast::PredAt { not: false, label:label }, - "@" "!" => ast::PredAt { not: true, label:label } -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra -InstBra: ast::Instruction> = { - "bra" => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a) -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt -InstCvt: ast::Instruction> = { - "cvt" => { - ast::Instruction::Cvt(ast::CvtDetails::new_int_from_int_checked( - s.is_some(), - dst_t, - src_t, - errors - ), - a) - }, - "cvt" => { - ast::Instruction::Cvt(ast::CvtDetails::new_float_from_int_checked( - r, - f.is_some(), - s.is_some(), - dst_t, - src_t, - errors - ), - a) - }, - "cvt" => { - ast::Instruction::Cvt(ast::CvtDetails::new_int_from_float_checked( - r, - f.is_some(), - s.is_some(), - dst_t, - src_t, - errors - ), - a) - }, - "cvt" ".f16" ".f16" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: r, - flush_to_zero: None, - saturate: s.is_some(), - dst: ast::ScalarType::F16, - src: ast::ScalarType::F16 - } - ), a) - }, - "cvt" ".f32" ".f16" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: None, - flush_to_zero: Some(f.is_some()), - saturate: s.is_some(), - dst: ast::ScalarType::F32, - src: ast::ScalarType::F16 - } - ), a) - }, - "cvt" ".f64" ".f16" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: None, - flush_to_zero: None, - saturate: s.is_some(), - dst: ast::ScalarType::F64, - src: ast::ScalarType::F16 - } - ), a) - }, - "cvt" ".f16" ".f32" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: Some(r), - flush_to_zero: Some(f.is_some()), - saturate: s.is_some(), - dst: ast::ScalarType::F16, - src: ast::ScalarType::F32 - } - ), a) - }, - "cvt" ".f32" ".f32" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: r, - flush_to_zero: Some(f.is_some()), - saturate: s.is_some(), - dst: ast::ScalarType::F32, - src: ast::ScalarType::F32 - } - ), a) - }, - "cvt" ".f64" ".f32" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: None, - flush_to_zero: Some(f.is_some()), - saturate: s.is_some(), - dst: ast::ScalarType::F64, - src: ast::ScalarType::F32 - } - ), a) - }, - "cvt" ".f16" ".f64" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: Some(r), - flush_to_zero: None, - saturate: s.is_some(), - dst: ast::ScalarType::F16, - src: ast::ScalarType::F64 - } - ), a) - }, - "cvt" ".f32" ".f64" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: Some(r), - flush_to_zero: Some(s.is_some()), - saturate: s.is_some(), - dst: ast::ScalarType::F32, - src: ast::ScalarType::F64 - } - ), a) - }, - "cvt" ".f64" ".f64" => { - ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat( - ast::CvtDesc { - rounding: r, - flush_to_zero: None, - saturate: s.is_some(), - dst: ast::ScalarType::F64, - src: ast::ScalarType::F64 - } - ), a) - }, -}; - -CvtTypeInt: ast::ScalarType = { - ".u8" => ast::ScalarType::U8, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s8" => ast::ScalarType::S8, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -}; - -CvtTypeFloat: ast::ScalarType = { - ".f16" => ast::ScalarType::F16, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl -InstShl: ast::Instruction> = { - "shl" => ast::Instruction::Shl(t, a) -}; - -ShlType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr -InstShr: ast::Instruction> = { - "shr" => ast::Instruction::Shr(t, a) -}; - -ShrType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st -// Warning: NVIDIA documentation is incorrect, you can specify scope only once -InstSt: ast::Instruction> = { - "st" "," => { - ast::Instruction::St( - ast::StData { - qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StateSpace::Generic), - caching: cop.unwrap_or(ast::StCacheOperator::Writeback), - typ: t - }, - ast::Arg2St { src1:src1, src2:src2 } - ) - } -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors -MemoryOperand: ast::Operand<&'input str> = { - "[" "]" => o -} - -StStateSpace: ast::StateSpace = { - ".global" => ast::StateSpace::Global, - ".local" => ast::StateSpace::Local, - ".param" => ast::StateSpace::Param, - ".shared" => ast::StateSpace::Shared, -}; - -StCacheOperator: ast::StCacheOperator = { - ".wb" => ast::StCacheOperator::Writeback, - ".cg" => ast::StCacheOperator::L2Only, - ".cs" => ast::StCacheOperator::Streaming, - ".wt" => ast::StCacheOperator::Writethrough, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret -InstRet: ast::Instruction> = { - "ret" => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() }) -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta -InstCvta: ast::Instruction> = { - "cvta" => { - ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::StateSpace::Generic, - from, - size: s - }, - a) - }, - "cvta" ".to" => { - ast::Instruction::Cvta(ast::CvtaDetails { - to, - from: ast::StateSpace::Generic, - size: s - }, - a) - } -} - -CvtaStateSpace: ast::StateSpace = { - ".const" => ast::StateSpace::Const, - ".global" => ast::StateSpace::Global, - ".local" => ast::StateSpace::Local, - ".shared" => ast::StateSpace::Shared, -} - -CvtaSize: ast::CvtaSize = { - ".u32" => ast::CvtaSize::U32, - ".u64" => ast::CvtaSize::U64, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call -InstCall: ast::Instruction> = { - "call" => { - let (ret_params, func, param_list) = args; - ast::Instruction::Call(ast::CallInst { uniform: u.is_some(), ret_params, func, param_list }) - } -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs -InstAbs: ast::Instruction> = { - "abs" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: t }, a) - }, - "abs" ".f32" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F32 }, a) - }, - "abs" ".f64" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: None, typ: ast::ScalarType::F64 }, a) - }, - "abs" ".f16" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16 }, a) - }, - "abs" ".f16x2" => { - ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: Some(f.is_some()), typ: ast::ScalarType::F16x2 }, a) - }, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad -InstMad: ast::Instruction> = { - "mad" => ast::Instruction::Mad(d, a), - "mad" ".hi" ".sat" ".s32" => todo!(), -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma -InstFma: ast::Instruction> = { - "fma" => ast::Instruction::Fma(f, a), -}; - -SignedIntType: ast::ScalarType = { - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or -InstOr: ast::Instruction> = { - "or" => ast::Instruction::Or(d, a), -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and -InstAnd: ast::Instruction> = { - "and" => ast::Instruction::And(d, a), -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp -InstRcp: ast::Instruction> = { - "rcp" ".f32" => { - let details = ast::RcpDetails { - rounding, - flush_to_zero: Some(ftz.is_some()), - is_f64: false, - }; - ast::Instruction::Rcp(details, a) - }, - "rcp" ".f64" => { - let details = ast::RcpDetails { - rounding: Some(rn), - flush_to_zero: None, - is_f64: true, - }; - ast::Instruction::Rcp(details, a) - } -}; - -RcpRoundingMode: Option = { - ".approx" => None, - => Some(r) -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub -InstSub: ast::Instruction> = { - "sub" => ast::Instruction::Sub(d, a), -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min -InstMin: ast::Instruction> = { - "min" => ast::Instruction::Min(d, a), -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max -InstMax: ast::Instruction> = { - "max" => ast::Instruction::Max(d, a), -}; - -MinMaxDetails: ast::MinMaxDetails = { - => ast::MinMaxDetails::Unsigned(t), - => ast::MinMaxDetails::Signed(t), - ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 } - ), - ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 } - ), - ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 } - ), - ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 } - ) -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp -InstSelp: ast::Instruction> = { - "selp" => ast::Instruction::Selp(t, a), -}; - -SelpType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar -InstBar: ast::Instruction> = { - "bar" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), - "barrier" ".sync" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), - "barrier" ".sync" ".aligned" => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom -// The documentation does not mention all spported operations: -// * Operation .add requires .u32 or .s32 or .u64 or .f64 or f16 or f16x2 or .f32 -// * Operation .inc requires .u32 type for instuction -// * Operation .dec requires .u32 type for instuction -// Otherwise as documented -InstAtom: ast::Instruction> = { - "atom" => { - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Bit { op, typ } - }; - ast::Instruction::Atom(details,a) - }, - "atom" ".inc" ".u32" => { - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Inc, - typ: ast::ScalarType::U32 - } - }; - ast::Instruction::Atom(details,a) - }, - "atom" ".dec" ".u32" => { - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Dec, - typ: ast::ScalarType::U32 - } - }; - ast::Instruction::Atom(details,a) - }, - "atom" ".add" => { - let op = ast::AtomFloatOp::Add; - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Float { op, typ } - }; - ast::Instruction::Atom(details,a) - }, - "atom" => { - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Unsigned { op, typ } - }; - ast::Instruction::Atom(details,a) - }, - "atom" => { - let details = ast::AtomDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - inner: ast::AtomInnerDetails::Signed { op, typ } - }; - ast::Instruction::Atom(details,a) - } -} - -InstAtomCas: ast::Instruction> = { - "atom" ".cas" => { - let details = ast::AtomCasDetails { - semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), - scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::StateSpace::Generic), - typ, - }; - ast::Instruction::AtomCas(details,a) - }, -} - -AtomSemantics: ast::AtomSemantics = { - ".relaxed" => ast::AtomSemantics::Relaxed, - ".acquire" => ast::AtomSemantics::Acquire, - ".release" => ast::AtomSemantics::Release, - ".acq_rel" => ast::AtomSemantics::AcquireRelease -} - -AtomSpace: ast::StateSpace = { - ".global" => ast::StateSpace::Global, - ".shared" => ast::StateSpace::Shared -} - -AtomBitOp: ast::AtomBitOp = { - ".and" => ast::AtomBitOp::And, - ".or" => ast::AtomBitOp::Or, - ".xor" => ast::AtomBitOp::Xor, - ".exch" => ast::AtomBitOp::Exchange, -} - -AtomUIntOp: ast::AtomUIntOp = { - ".add" => ast::AtomUIntOp::Add, - ".min" => ast::AtomUIntOp::Min, - ".max" => ast::AtomUIntOp::Max, -} - -AtomSIntOp: ast::AtomSIntOp = { - ".add" => ast::AtomSIntOp::Add, - ".min" => ast::AtomSIntOp::Min, - ".max" => ast::AtomSIntOp::Max, -} - -BitType: ast::ScalarType = { - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, -} - -UIntType3264: ast::ScalarType = { - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, -} - -SIntType3264: ast::ScalarType = { - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div -InstDiv: ast::Instruction> = { - "div" => ast::Instruction::Div(ast::DivDetails::Unsigned(t), a), - "div" => ast::Instruction::Div(ast::DivDetails::Signed(t), a), - "div" ".f32" => { - let inner = ast::DivFloatDetails { - typ: ast::ScalarType::F32, - flush_to_zero: Some(ftz.is_some()), - kind - }; - ast::Instruction::Div(ast::DivDetails::Float(inner), a) - }, - "div" ".f64" => { - let inner = ast::DivFloatDetails { - typ: ast::ScalarType::F64, - flush_to_zero: None, - kind: ast::DivFloatKind::Rounding(rnd) - }; - ast::Instruction::Div(ast::DivDetails::Float(inner), a) - }, -} - -DivFloatKind: ast::DivFloatKind = { - ".approx" => ast::DivFloatKind::Approx, - ".full" => ast::DivFloatKind::Full, - => ast::DivFloatKind::Rounding(rnd), -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt -InstSqrt: ast::Instruction> = { - "sqrt" ".approx" ".f32" => { - let details = ast::SqrtDetails { - typ: ast::ScalarType::F32, - flush_to_zero: Some(ftz.is_some()), - kind: ast::SqrtKind::Approx, - }; - ast::Instruction::Sqrt(details, a) - }, - "sqrt" ".f32" => { - let details = ast::SqrtDetails { - typ: ast::ScalarType::F32, - flush_to_zero: Some(ftz.is_some()), - kind: ast::SqrtKind::Rounding(rnd), - }; - ast::Instruction::Sqrt(details, a) - }, - "sqrt" ".f64" => { - let details = ast::SqrtDetails { - typ: ast::ScalarType::F64, - flush_to_zero: None, - kind: ast::SqrtKind::Rounding(rnd), - }; - ast::Instruction::Sqrt(details, a) - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 -InstRsqrt: ast::Instruction> = { - "rsqrt" ".approx" ".f32" => { - let details = ast::RsqrtDetails { - typ: ast::ScalarType::F32, - flush_to_zero: ftz.is_some(), - }; - ast::Instruction::Rsqrt(details, a) - }, - "rsqrt" ".approx" ".f64" => { - let details = ast::RsqrtDetails { - typ: ast::ScalarType::F64, - flush_to_zero: ftz.is_some(), - }; - ast::Instruction::Rsqrt(details, a) - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg -InstNeg: ast::Instruction> = { - "neg" => { - let details = ast::NegDetails { - typ, - flush_to_zero: Some(ftz.is_some()), - }; - ast::Instruction::Neg(details, a) - }, - "neg" => { - let details = ast::NegDetails { - typ, - flush_to_zero: None, - }; - ast::Instruction::Neg(details, a) - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin -InstSin: ast::Instruction> = { - "sin" ".approx" ".f32" => { - ast::Instruction::Sin{ flush_to_zero: ftz.is_some(), arg } - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos -InstCos: ast::Instruction> = { - "cos" ".approx" ".f32" => { - ast::Instruction::Cos{ flush_to_zero: ftz.is_some(), arg } - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2 -InstLg2: ast::Instruction> = { - "lg2" ".approx" ".f32" => { - ast::Instruction::Lg2{ flush_to_zero: ftz.is_some(), arg } - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2 -InstEx2: ast::Instruction> = { - "ex2" ".approx" ".f32" => { - ast::Instruction::Ex2{ flush_to_zero: ftz.is_some(), arg } - }, -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz -InstClz: ast::Instruction> = { - "clz" => ast::Instruction::Clz{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev -InstBrev: ast::Instruction> = { - "brev" => ast::Instruction::Brev{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc -InstPopc: ast::Instruction> = { - "popc" => ast::Instruction::Popc{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor -InstXor: ast::Instruction> = { - "xor" => ast::Instruction::Xor{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe -InstBfe: ast::Instruction> = { - "bfe" => ast::Instruction::Bfe{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi -InstBfi: ast::Instruction> = { - "bfi" => ast::Instruction::Bfi{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt -InstPrmt: ast::Instruction> = { - "prmt" ".b32" "," => ast::Instruction::Prmt{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem -InstRem: ast::Instruction> = { - "rem" => ast::Instruction::Rem{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask -InstActivemask: ast::Instruction> = { - "activemask" ".b32" => ast::Instruction::Activemask{ <> } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar -InstMembar: ast::Instruction> = { - "membar" => ast::Instruction::Membar{ <> } -} - -NegTypeFtz: ast::ScalarType = { - ".f16" => ast::ScalarType::F16, - ".f16x2" => ast::ScalarType::F16x2, - ".f32" => ast::ScalarType::F32, -} - -NegTypeNonFtz: ast::ScalarType = { - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f64" => ast::ScalarType::F64 -} - -ArithDetails: ast::ArithDetails = { - => ast::ArithDetails::Unsigned(t), - => ast::ArithDetails::Signed(ast::ArithSInt { - typ: t, - saturate: false, - }), - ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S32, - saturate: true, - }), - => ast::ArithDetails::Float(f) -} - -ArithFloat: ast::ArithFloat = { - ".f32" => ast::ArithFloat { - typ: ast::ScalarType::F32, - rounding: rn, - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, - ".f64" => ast::ArithFloat { - typ: ast::ScalarType::F64, - rounding: rn, - flush_to_zero: None, - saturate: false, - }, - ".f16" => ast::ArithFloat { - typ: ast::ScalarType::F16, - rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, - ".f16x2" => ast::ArithFloat { - typ: ast::ScalarType::F16x2, - rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, -} - -ArithFloatMustRound: ast::ArithFloat = { - ".f32" => ast::ArithFloat { - typ: ast::ScalarType::F32, - rounding: Some(rn), - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, - ".f64" => ast::ArithFloat { - typ: ast::ScalarType::F64, - rounding: Some(rn), - flush_to_zero: None, - saturate: false, - }, - ".rn" ".f16" => ast::ArithFloat { - typ: ast::ScalarType::F16, - rounding: Some(ast::RoundingMode::NearestEven), - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, - ".rn" ".f16x2" => ast::ArithFloat { - typ: ast::ScalarType::F16x2, - rounding: Some(ast::RoundingMode::NearestEven), - flush_to_zero: Some(ftz.is_some()), - saturate: sat.is_some(), - }, -} - -Operand: ast::Operand<&'input str> = { - => ast::Operand::Reg(r), - "+" => ast::Operand::RegOffset(r, offset), - => ast::Operand::Imm(x) -}; - -CallOperand: ast::Operand<&'input str> = { - => ast::Operand::Reg(r), - => ast::Operand::Imm(x) -}; - -// TODO: start parsing whole constants sub-language: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants -ImmediateValue: ast::ImmediateValue = { - // TODO: treat negation correctly - => { - let (num, radix, is_unsigned) = x; - if neg.is_some() { - match i64::from_str_radix(num, radix) { - Ok(x) => ast::ImmediateValue::S64(-x), - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - ast::ImmediateValue::S64(0) - } - } - } else if is_unsigned { - match u64::from_str_radix(num, radix) { - Ok(x) => ast::ImmediateValue::U64(x), - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - ast::ImmediateValue::U64(0) - } - } - } else { - match i64::from_str_radix(num, radix) { - Ok(x) => ast::ImmediateValue::S64(x), - Err(_) => { - match u64::from_str_radix(num, radix) { - Ok(x) => ast::ImmediateValue::U64(x), - Err(err) => { - errors.push(ParseError::User { error: ast::PtxError::from(err) }); - ast::ImmediateValue::U64(0) - } - } - } - } - } - }, - => { - ast::ImmediateValue::F32(f) - }, - => { - ast::ImmediateValue::F64(f) - } -} - -Arg1: ast::Arg1> = { - => ast::Arg1{<>} -}; - -Arg1Bar: ast::Arg1Bar> = { - => ast::Arg1Bar{<>} -}; - -Arg2: ast::Arg2> = { - "," => ast::Arg2{<>} -}; - -MemberOperand: (&'input str, u8) = { - "." => { - let suf_idx = match vector_index(suf) { - Ok(x) => x, - Err(err) => { - errors.push(err); - 0 - } - }; - (pref, suf_idx) - }, - => { - let suf_idx = match vector_index(&suf[1..]) { - Ok(x) => x, - Err(err) => { - errors.push(err); - 0 - } - }; - (pref, suf_idx) - } -}; - -VectorExtract: Vec<&'input str> = { - "{" "," "}" => { - vec![r1, r2] - }, - "{" "," "," "," "}" => { - vec![r1, r2, r3, r4] - }, -}; - -Arg3: ast::Arg3> = { - "," "," => ast::Arg3{<>} -}; - -Arg3Atom: ast::Arg3> = { - "," "[" "]" "," => ast::Arg3{<>} -}; - -Arg4: ast::Arg4> = { - "," "," "," => ast::Arg4{<>} -}; - -Arg4Atom: ast::Arg4> = { - "," "[" "]" "," "," => ast::Arg4{<>} -}; - -Arg4Setp: ast::Arg4Setp> = { - "," "," => ast::Arg4Setp{<>} -}; - -Arg5: ast::Arg5> = { - "," "," "," "," => ast::Arg5{<>} -}; - -// TODO: pass src3 negation somewhere -Arg5Setp: ast::Arg5Setp> = { - "," "," "," "!"? => ast::Arg5Setp{<>} -}; - -ArgCall: (Vec<&'input str>, &'input str, Vec>) = { - "(" > ")" "," "," "(" > ")" => { - (ret_params, func, param_list) - }, - "(" > ")" "," => { - (ret_params, func, Vec::new()) - }, - "," "(" > ")" => (Vec::new(), func, param_list), - => (Vec::new(), func, Vec::>::new()), -}; - -OptionalDst: &'input str = { - "|" => dst2 -} - -SrcOperand: ast::Operand<&'input str> = { - => ast::Operand::Reg(r), - "+" => ast::Operand::RegOffset(r, offset), - => ast::Operand::Imm(x), - => { - let (reg, idx) = mem_op; - ast::Operand::VecMember(reg, idx) - } -} - -SrcOperandVec: ast::Operand<&'input str> = { - => normal, - => ast::Operand::VecPack(vec), -} - -DstOperand: ast::Operand<&'input str> = { - => ast::Operand::Reg(r), - => { - let (reg, idx) = mem_op; - ast::Operand::VecMember(reg, idx) - } -} - -DstOperandVec: ast::Operand<&'input str> = { - => normal, - => ast::Operand::VecPack(vec), -} - -VectorPrefix: u8 = { - ".v2" => 2, - ".v4" => 4 -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file -File = { - ".file" U32Num String ("," U32Num "," U32Num)? -}; - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-section -Section = { - ".section" DotID "{" SectionDwarfLines* "}" -}; - -SectionDwarfLines: () = { - AnyBitType Comma, - ".b32" SectionLabel, - ".b64" SectionLabel, - ".b32" SectionLabel "+" U32Num, - ".b64" SectionLabel "+" U32Num, -}; - -SectionLabel = { - ID, - DotID -}; - -AnyBitType = { - ".b8", ".b16", ".b32", ".b64" -}; - -VariableScalar: (Option, T, &'input str) = { - => { - (align, v_type, name) - } -} - -VariableVector: (Option, u8, T, &'input str) = { - => { - (align, v_len, v_type, name) - } -} - -// empty dimensions [0] means it's a pointer -VariableArrayOrPointer: (Option, T, &'input str, ast::ArrayOrPointer) = { - => { - let mut dims = dims; - let array_init = match init { - Some(init) => { - let init_vec = match init.to_vec(typ, &mut dims) { - Err(error) => { - errors.push(ParseError::User { error }); - Vec::new() - } - Ok(x) => x - }; - ast::ArrayOrPointer::Array { dimensions: dims, init: init_vec } - } - None => { - if dims.len() > 1 && dims.contains(&0) { - errors.push(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); - } - match &*dims { - [0] => ast::ArrayOrPointer::Pointer, - _ => ast::ArrayOrPointer::Array { dimensions: dims, init: Vec::new() } - } - } - }; - (align, typ, name, array_init) - } -} - -// [0] and [] are treated the same -ArrayDimensions: Vec = { - ArrayEmptyDimension => vec![0u32], - ArrayEmptyDimension => { - let mut dims = dims; - let mut result = vec![0u32]; - result.append(&mut dims); - result - }, - => dims -} - -ArrayEmptyDimension = { - "[" "]" -} - -ArrayDimension: u32 = { - "[" "]" => n, -} - -ArrayInitializer: ast::NumsOrArrays<'input> = { - "=" => nums -} - -NumsOrArraysBracket: ast::NumsOrArrays<'input> = { - "{" "}" => nums -} - -NumsOrArrays: ast::NumsOrArrays<'input> = { - > => ast::NumsOrArrays::Arrays(n), - > => ast::NumsOrArrays::Nums(n.into_iter().map(|(x,radix,_)| (x, radix)).collect()), -} - -Comma: Vec = { - ",")*> => match e { - None => v, - Some(e) => { - let mut v = v; - v.push(e); - v - } - } -}; - -CommaNonEmpty: Vec = { - ",")*> => { - let mut v = v; - v.push(e); - v - } -}; - -#[inline] -Or: T1 = { - T1, - T2 -} - -#[inline] -Or3: T1 = { - T1, - T2, - T3 -} \ No newline at end of file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 0785f3e..e9943f4 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -1,18 +1,15 @@ -use super::ptx; -use super::TranslateError; +use crate::pass::TranslateError; +use ptx_parser as ast; mod spirv_run; -fn parse_and_assert(s: &str) { - let mut errors = Vec::new(); - ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); - assert!(errors.len() == 0); +fn parse_and_assert(ptx_text: &str) { + ast::parse_module_checked(ptx_text).unwrap(); } -fn compile_and_assert(s: &str) -> Result<(), TranslateError> { - let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); - crate::to_spirv_module(ast)?; +fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> { + let ast = ast::parse_module_checked(ptx_text).unwrap(); + crate::to_llvm_module(ast)?; Ok(()) } diff --git a/ptx/src/test/spirv_run/activemask.spvtxt b/ptx/src/test/spirv_run/activemask.spvtxt deleted file mode 100644 index 0753c95..0000000 --- a/ptx/src/test/spirv_run/activemask.spvtxt +++ /dev/null @@ -1,45 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %18 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "activemask" - OpExecutionMode %1 ContractionOff - OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %21 = OpTypeFunction %uint - %ulong = OpTypeInt 64 0 - %23 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %15 = OpFunction %uint None %21 - OpFunctionEnd - %1 = OpFunction %void None %23 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %14 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_uint Function - OpStore %2 %6 - OpStore %3 %7 - %8 = OpLoad %ulong %3 Aligned 8 - OpStore %4 %8 - %9 = OpFunctionCall %uint %15 - OpStore %5 %9 - %10 = OpLoad %ulong %4 - %11 = OpLoad %uint %5 - %12 = OpConvertUToPtr %_ptr_Generic_uint %10 - %13 = OpCopyObject %uint %11 - OpStore %12 %13 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/add.spvtxt b/ptx/src/test/spirv_run/add.spvtxt deleted file mode 100644 index b468693..0000000 --- a/ptx/src/test/spirv_run/add.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %26 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpIAdd %ulong %15 %ulong_1 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/add_non_coherent.spvtxt b/ptx/src/test/spirv_run/add_non_coherent.spvtxt deleted file mode 100644 index 99da980..0000000 --- a/ptx/src/test/spirv_run/add_non_coherent.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add_non_coherent" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %26 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpIAdd %ulong %15 %ulong_1 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - OpStore %20 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/add_tuning.spvtxt b/ptx/src/test/spirv_run/add_tuning.spvtxt deleted file mode 100644 index d65f04d..0000000 --- a/ptx/src/test/spirv_run/add_tuning.spvtxt +++ /dev/null @@ -1,55 +0,0 @@ -; SPIR-V -; Version: 1.3 -; Generator: rspirv -; Bound: 29 -OpCapability GenericPointer -OpCapability Linkage -OpCapability Addresses -OpCapability Kernel -OpCapability Int8 -OpCapability Int16 -OpCapability Int64 -OpCapability Float16 -OpCapability Float64 -OpCapability DenormFlushToZero -%23 = OpExtInstImport "OpenCL.std" -OpMemoryModel Physical64 OpenCL -OpEntryPoint Kernel %1 "add_tuning" -OpExecutionMode %1 ContractionOff -; OpExecutionMode %1 MaxWorkgroupSizeINTEL 256 1 1 -OpDecorate %1 LinkageAttributes "add_tuning" Export -%24 = OpTypeVoid -%25 = OpTypeInt 64 0 -%26 = OpTypeFunction %24 %25 %25 -%27 = OpTypePointer Function %25 -%28 = OpTypePointer Generic %25 -%18 = OpConstant %25 1 -%1 = OpFunction %24 None %26 -%8 = OpFunctionParameter %25 -%9 = OpFunctionParameter %25 -%21 = OpLabel -%2 = OpVariable %27 Function -%3 = OpVariable %27 Function -%4 = OpVariable %27 Function -%5 = OpVariable %27 Function -%6 = OpVariable %27 Function -%7 = OpVariable %27 Function -OpStore %2 %8 -OpStore %3 %9 -%10 = OpLoad %25 %2 Aligned 8 -OpStore %4 %10 -%11 = OpLoad %25 %3 Aligned 8 -OpStore %5 %11 -%13 = OpLoad %25 %4 -%19 = OpConvertUToPtr %28 %13 -%12 = OpLoad %25 %19 Aligned 8 -OpStore %6 %12 -%15 = OpLoad %25 %6 -%14 = OpIAdd %25 %15 %18 -OpStore %7 %14 -%16 = OpLoad %25 %5 -%17 = OpLoad %25 %7 -%20 = OpConvertUToPtr %28 %16 -OpStore %20 %17 Aligned 8 -OpReturn -OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt deleted file mode 100644 index f66639a..0000000 --- a/ptx/src/test/spirv_run/and.spvtxt +++ /dev/null @@ -1,62 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %31 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "and" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %34 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %34 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %29 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %41 = OpBitcast %_ptr_Generic_uchar %24 - %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %42 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %26 = OpCopyObject %uint %17 - %27 = OpCopyObject %uint %18 - %25 = OpBitwiseAnd %uint %26 %27 - %16 = OpCopyObject %uint %25 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %28 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %28 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/assertfail.spvtxt b/ptx/src/test/spirv_run/assertfail.spvtxt deleted file mode 100644 index 8ed84fa..0000000 --- a/ptx/src/test/spirv_run/assertfail.spvtxt +++ /dev/null @@ -1,105 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %67 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %12 "assertfail" - OpDecorate %1 LinkageAttributes "__zluda_ptx_impl____assertfail" Import - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %73 = OpTypeFunction %void %_ptr_Function_ulong %_ptr_Function_ulong %_ptr_Function_uint %_ptr_Function_ulong %_ptr_Function_ulong - %74 = OpTypeFunction %void %ulong %ulong - %uint_0 = OpConstant %uint 0 - %ulong_0 = OpConstant %ulong 0 - %uchar = OpTypeInt 8 0 -%_ptr_Function_uchar = OpTypePointer Function %uchar - %ulong_0_0 = OpConstant %ulong 0 - %ulong_0_1 = OpConstant %ulong 0 - %ulong_0_2 = OpConstant %ulong 0 - %ulong_0_3 = OpConstant %ulong 0 -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %73 - %61 = OpFunctionParameter %_ptr_Function_ulong - %62 = OpFunctionParameter %_ptr_Function_ulong - %63 = OpFunctionParameter %_ptr_Function_uint - %64 = OpFunctionParameter %_ptr_Function_ulong - %65 = OpFunctionParameter %_ptr_Function_ulong - OpFunctionEnd - %12 = OpFunction %void None %74 - %25 = OpFunctionParameter %ulong - %26 = OpFunctionParameter %ulong - %60 = OpLabel - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - %16 = OpVariable %_ptr_Function_ulong Function - %17 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_uint Function - %20 = OpVariable %_ptr_Function_ulong Function - %21 = OpVariable %_ptr_Function_ulong Function - %22 = OpVariable %_ptr_Function_uint Function - %23 = OpVariable %_ptr_Function_ulong Function - %24 = OpVariable %_ptr_Function_ulong Function - OpStore %13 %25 - OpStore %14 %26 - %27 = OpLoad %ulong %13 Aligned 8 - OpStore %15 %27 - %28 = OpLoad %ulong %14 Aligned 8 - OpStore %16 %28 - %53 = OpCopyObject %uint %uint_0 - %29 = OpCopyObject %uint %53 - OpStore %19 %29 - %30 = OpLoad %ulong %15 - %77 = OpBitcast %_ptr_Function_uchar %20 - %78 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %77 %ulong_0 - %43 = OpBitcast %_ptr_Function_ulong %78 - %54 = OpCopyObject %ulong %30 - OpStore %43 %54 Aligned 8 - %31 = OpLoad %ulong %15 - %79 = OpBitcast %_ptr_Function_uchar %21 - %80 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %79 %ulong_0_0 - %45 = OpBitcast %_ptr_Function_ulong %80 - %55 = OpCopyObject %ulong %31 - OpStore %45 %55 Aligned 8 - %32 = OpLoad %uint %19 - %81 = OpBitcast %_ptr_Function_uchar %22 - %82 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %81 %ulong_0_1 - %47 = OpBitcast %_ptr_Function_uint %82 - OpStore %47 %32 Aligned 4 - %33 = OpLoad %ulong %15 - %83 = OpBitcast %_ptr_Function_uchar %23 - %84 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %83 %ulong_0_2 - %49 = OpBitcast %_ptr_Function_ulong %84 - %56 = OpCopyObject %ulong %33 - OpStore %49 %56 Aligned 8 - %34 = OpLoad %ulong %15 - %85 = OpBitcast %_ptr_Function_uchar %24 - %86 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %85 %ulong_0_3 - %51 = OpBitcast %_ptr_Function_ulong %86 - %57 = OpCopyObject %ulong %34 - OpStore %51 %57 Aligned 8 - %87 = OpFunctionCall %void %1 %20 %21 %22 %23 %24 - %36 = OpLoad %ulong %15 - %58 = OpConvertUToPtr %_ptr_Generic_ulong %36 - %35 = OpLoad %ulong %58 Aligned 8 - OpStore %17 %35 - %38 = OpLoad %ulong %17 - %37 = OpIAdd %ulong %38 %ulong_1 - OpStore %18 %37 - %39 = OpLoad %ulong %16 - %40 = OpLoad %ulong %18 - %59 = OpConvertUToPtr %_ptr_Generic_ulong %39 - OpStore %59 %40 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt deleted file mode 100644 index 987fdef..0000000 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ /dev/null @@ -1,85 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %38 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add" %4 - OpExecutionMode %1 ContractionOff - OpDecorate %4 Alignment 4 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_1024 = OpConstant %uint 1024 -%_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 -%_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup - %ulong = OpTypeInt 64 0 - %46 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %uint_1 = OpConstant %uint 1 - %uint_0 = OpConstant %uint 0 - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %46 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %36 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %12 - %14 = OpLoad %ulong %5 - %29 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %29 Aligned 4 - OpStore %7 %13 - %16 = OpLoad %ulong %5 - %30 = OpConvertUToPtr %_ptr_Generic_uint %16 - %51 = OpBitcast %_ptr_Generic_uchar %30 - %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 - %26 = OpBitcast %_ptr_Generic_uint %52 - %15 = OpLoad %uint %26 Aligned 4 - OpStore %8 %15 - %17 = OpLoad %uint %7 - %31 = OpBitcast %_ptr_Workgroup_uint %4 - OpStore %31 %17 Aligned 4 - %19 = OpLoad %uint %8 - %32 = OpBitcast %_ptr_Workgroup_uint %4 - %18 = OpAtomicIAdd %uint %32 %uint_1 %uint_0 %19 - OpStore %7 %18 - %33 = OpBitcast %_ptr_Workgroup_uint %4 - %20 = OpLoad %uint %33 Aligned 4 - OpStore %8 %20 - %21 = OpLoad %ulong %6 - %22 = OpLoad %uint %7 - %34 = OpConvertUToPtr %_ptr_Generic_uint %21 - OpStore %34 %22 Aligned 4 - %23 = OpLoad %ulong %6 - %24 = OpLoad %uint %8 - %35 = OpConvertUToPtr %_ptr_Generic_uint %23 - %56 = OpBitcast %_ptr_Generic_uchar %35 - %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_uint %57 - OpStore %28 %24 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt deleted file mode 100644 index 067c347..0000000 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ /dev/null @@ -1,90 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %42 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add_float" %4 - OpExecutionMode %1 ContractionOff - OpDecorate %37 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_shared_add_f32" Import - OpDecorate %4 Alignment 4 - %void = OpTypeVoid - %float = OpTypeFloat 32 -%_ptr_Workgroup_float = OpTypePointer Workgroup %float - %46 = OpTypeFunction %float %_ptr_Workgroup_float %float - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_1024 = OpConstant %uint 1024 -%_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 -%_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup - %ulong = OpTypeInt 64 0 - %53 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_4_0 = OpConstant %ulong 4 - %37 = OpFunction %float None %46 - %39 = OpFunctionParameter %_ptr_Workgroup_float - %40 = OpFunctionParameter %float - OpFunctionEnd - %1 = OpFunction %void None %53 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %36 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %12 - %14 = OpLoad %ulong %5 - %29 = OpConvertUToPtr %_ptr_Generic_float %14 - %13 = OpLoad %float %29 Aligned 4 - OpStore %7 %13 - %16 = OpLoad %ulong %5 - %30 = OpConvertUToPtr %_ptr_Generic_float %16 - %58 = OpBitcast %_ptr_Generic_uchar %30 - %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 - %26 = OpBitcast %_ptr_Generic_float %59 - %15 = OpLoad %float %26 Aligned 4 - OpStore %8 %15 - %17 = OpLoad %float %7 - %31 = OpBitcast %_ptr_Workgroup_float %4 - OpStore %31 %17 Aligned 4 - %19 = OpLoad %float %8 - %32 = OpBitcast %_ptr_Workgroup_float %4 - %18 = OpFunctionCall %float %37 %32 %19 - OpStore %7 %18 - %33 = OpBitcast %_ptr_Workgroup_float %4 - %20 = OpLoad %float %33 Aligned 4 - OpStore %8 %20 - %21 = OpLoad %ulong %6 - %22 = OpLoad %float %7 - %34 = OpConvertUToPtr %_ptr_Generic_float %21 - OpStore %34 %22 Aligned 4 - %23 = OpLoad %ulong %6 - %24 = OpLoad %float %8 - %35 = OpConvertUToPtr %_ptr_Generic_float %23 - %60 = OpBitcast %_ptr_Generic_uchar %35 - %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_float %61 - OpStore %28 %24 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt deleted file mode 100644 index 7c2f4fa..0000000 --- a/ptx/src/test/spirv_run/atom_cas.spvtxt +++ /dev/null @@ -1,77 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %39 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_cas" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %42 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %uint_100 = OpConstant %uint 100 - %uint_1 = OpConstant %uint 1 - %uint_0 = OpConstant %uint 0 - %ulong_4_0 = OpConstant %ulong 4 - %ulong_4_1 = OpConstant %ulong 4 - %1 = OpFunction %void None %42 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %37 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %30 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %16 = OpLoad %uint %6 - %31 = OpConvertUToPtr %_ptr_Generic_uint %15 - %49 = OpBitcast %_ptr_Generic_uchar %31 - %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_4 - %24 = OpBitcast %_ptr_Generic_uint %50 - %33 = OpCopyObject %uint %16 - %32 = OpAtomicCompareExchange %uint %24 %uint_1 %uint_0 %uint_0 %uint_100 %33 - %14 = OpCopyObject %uint %32 - OpStore %6 %14 - %18 = OpLoad %ulong %4 - %34 = OpConvertUToPtr %_ptr_Generic_uint %18 - %53 = OpBitcast %_ptr_Generic_uchar %34 - %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4_0 - %27 = OpBitcast %_ptr_Generic_uint %54 - %17 = OpLoad %uint %27 Aligned 4 - OpStore %7 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %35 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %35 %20 Aligned 4 - %21 = OpLoad %ulong %5 - %22 = OpLoad %uint %7 - %36 = OpConvertUToPtr %_ptr_Generic_uint %21 - %55 = OpBitcast %_ptr_Generic_uchar %36 - %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4_1 - %29 = OpBitcast %_ptr_Generic_uint %56 - OpStore %29 %22 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt deleted file mode 100644 index 4855cd4..0000000 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ /dev/null @@ -1,87 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %47 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_inc" - OpDecorate %38 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import - OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint - %ulong = OpTypeInt 64 0 - %55 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint - %uint_101 = OpConstant %uint 101 - %uint_101_0 = OpConstant %uint 101 - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %38 = OpFunction %uint None %51 - %40 = OpFunctionParameter %_ptr_Generic_uint - %41 = OpFunctionParameter %uint - OpFunctionEnd - %42 = OpFunction %uint None %53 - %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint - %45 = OpFunctionParameter %uint - OpFunctionEnd - %1 = OpFunction %void None %55 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %37 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %31 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpFunctionCall %uint %38 %31 %uint_101 - OpStore %6 %13 - %16 = OpLoad %ulong %4 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 - %15 = OpFunctionCall %uint %42 %32 %uint_101_0 - OpStore %7 %15 - %18 = OpLoad %ulong %4 - %33 = OpConvertUToPtr %_ptr_Generic_uint %18 - %17 = OpLoad %uint %33 Aligned 4 - OpStore %8 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %34 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %34 %20 Aligned 4 - %21 = OpLoad %ulong %5 - %22 = OpLoad %uint %7 - %35 = OpConvertUToPtr %_ptr_Generic_uint %21 - %60 = OpBitcast %_ptr_Generic_uchar %35 - %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4 - %28 = OpBitcast %_ptr_Generic_uint %61 - OpStore %28 %22 Aligned 4 - %23 = OpLoad %ulong %5 - %24 = OpLoad %uint %8 - %36 = OpConvertUToPtr %_ptr_Generic_uint %23 - %62 = OpBitcast %_ptr_Generic_uchar %36 - %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_8 - %30 = OpBitcast %_ptr_Generic_uint %63 - OpStore %30 %24 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/b64tof64.spvtxt b/ptx/src/test/spirv_run/b64tof64.spvtxt deleted file mode 100644 index 54ac111..0000000 --- a/ptx/src/test/spirv_run/b64tof64.spvtxt +++ /dev/null @@ -1,50 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "b64tof64" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %double = OpTypeFloat 64 -%_ptr_Function_double = OpTypePointer Function %double -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_double Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %18 = OpBitcast %_ptr_Function_double %2 - %10 = OpLoad %double %18 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %double %4 - %19 = OpBitcast %ulong %13 - %12 = OpCopyObject %ulong %19 - OpStore %5 %12 - %15 = OpLoad %ulong %5 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %14 = OpLoad %ulong %20 Aligned 8 - OpStore %7 %14 - %16 = OpLoad %ulong %6 - %17 = OpLoad %ulong %7 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %21 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt deleted file mode 100644 index 0001808..0000000 --- a/ptx/src/test/spirv_run/bfe.spvtxt +++ /dev/null @@ -1,76 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %40 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "bfe" - OpDecorate %34 LinkageAttributes "__zluda_ptx_impl__bfe_u32" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %43 = OpTypeFunction %uint %uint %uint %uint - %ulong = OpTypeInt 64 0 - %45 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %34 = OpFunction %uint None %43 - %36 = OpFunctionParameter %uint - %37 = OpFunctionParameter %uint - %38 = OpFunctionParameter %uint - OpFunctionEnd - %1 = OpFunction %void None %45 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %33 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %29 Aligned 4 - OpStore %6 %13 - %16 = OpLoad %ulong %4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %16 - %51 = OpBitcast %_ptr_Generic_uchar %30 - %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 - %26 = OpBitcast %_ptr_Generic_uint %52 - %15 = OpLoad %uint %26 Aligned 4 - OpStore %7 %15 - %18 = OpLoad %ulong %4 - %31 = OpConvertUToPtr %_ptr_Generic_uint %18 - %53 = OpBitcast %_ptr_Generic_uchar %31 - %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_8 - %28 = OpBitcast %_ptr_Generic_uint %54 - %17 = OpLoad %uint %28 Aligned 4 - OpStore %8 %17 - %20 = OpLoad %uint %6 - %21 = OpLoad %uint %7 - %22 = OpLoad %uint %8 - %19 = OpFunctionCall %uint %34 %20 %21 %22 - OpStore %6 %19 - %23 = OpLoad %ulong %5 - %24 = OpLoad %uint %6 - %32 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %32 %24 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt deleted file mode 100644 index 1979939..0000000 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ /dev/null @@ -1,90 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %51 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "bfi" - OpDecorate %44 LinkageAttributes "__zluda_ptx_impl__bfi_b32" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %54 = OpTypeFunction %uint %uint %uint %uint %uint - %ulong = OpTypeInt 64 0 - %56 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %ulong_12 = OpConstant %ulong 12 - %44 = OpFunction %uint None %54 - %46 = OpFunctionParameter %uint - %47 = OpFunctionParameter %uint - %48 = OpFunctionParameter %uint - %49 = OpFunctionParameter %uint - OpFunctionEnd - %1 = OpFunction %void None %56 - %10 = OpFunctionParameter %ulong - %11 = OpFunctionParameter %ulong - %43 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - %9 = OpVariable %_ptr_Function_uint Function - OpStore %2 %10 - OpStore %3 %11 - %12 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %12 - %13 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %13 - %15 = OpLoad %ulong %4 - %35 = OpConvertUToPtr %_ptr_Generic_uint %15 - %14 = OpLoad %uint %35 Aligned 4 - OpStore %6 %14 - %17 = OpLoad %ulong %4 - %36 = OpConvertUToPtr %_ptr_Generic_uint %17 - %62 = OpBitcast %_ptr_Generic_uchar %36 - %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 - %30 = OpBitcast %_ptr_Generic_uint %63 - %16 = OpLoad %uint %30 Aligned 4 - OpStore %7 %16 - %19 = OpLoad %ulong %4 - %37 = OpConvertUToPtr %_ptr_Generic_uint %19 - %64 = OpBitcast %_ptr_Generic_uchar %37 - %65 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %64 %ulong_8 - %32 = OpBitcast %_ptr_Generic_uint %65 - %18 = OpLoad %uint %32 Aligned 4 - OpStore %8 %18 - %21 = OpLoad %ulong %4 - %38 = OpConvertUToPtr %_ptr_Generic_uint %21 - %66 = OpBitcast %_ptr_Generic_uchar %38 - %67 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %66 %ulong_12 - %34 = OpBitcast %_ptr_Generic_uint %67 - %20 = OpLoad %uint %34 Aligned 4 - OpStore %9 %20 - %23 = OpLoad %uint %6 - %24 = OpLoad %uint %7 - %25 = OpLoad %uint %8 - %26 = OpLoad %uint %9 - %40 = OpCopyObject %uint %23 - %41 = OpCopyObject %uint %24 - %39 = OpFunctionCall %uint %44 %40 %41 %25 %26 - %22 = OpCopyObject %uint %39 - OpStore %6 %22 - %27 = OpLoad %ulong %5 - %28 = OpLoad %uint %6 - %42 = OpConvertUToPtr %_ptr_Generic_uint %27 - OpStore %42 %28 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/block.spvtxt b/ptx/src/test/spirv_run/block.spvtxt deleted file mode 100644 index 6921c04..0000000 --- a/ptx/src/test/spirv_run/block.spvtxt +++ /dev/null @@ -1,52 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %27 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "block" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %30 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %ulong_1_0 = OpConstant %ulong 1 - %1 = OpFunction %void None %30 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %25 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_ulong %14 - %13 = OpLoad %ulong %23 Aligned 8 - OpStore %6 %13 - %16 = OpLoad %ulong %6 - %15 = OpIAdd %ulong %16 %ulong_1 - OpStore %7 %15 - %18 = OpLoad %ulong %8 - %17 = OpIAdd %ulong %18 %ulong_1_0 - OpStore %8 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %ulong %7 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %19 - OpStore %24 %20 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bra.spvtxt b/ptx/src/test/spirv_run/bra.spvtxt deleted file mode 100644 index c2c1e1c..0000000 --- a/ptx/src/test/spirv_run/bra.spvtxt +++ /dev/null @@ -1,57 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %29 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "bra" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %ulong_2 = OpConstant %ulong 2 - %1 = OpFunction %void None %32 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %27 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %11 - OpStore %3 %12 - %13 = OpLoad %ulong %2 Aligned 8 - OpStore %7 %13 - %14 = OpLoad %ulong %3 Aligned 8 - OpStore %8 %14 - %16 = OpLoad %ulong %7 - %25 = OpConvertUToPtr %_ptr_Generic_ulong %16 - %15 = OpLoad %ulong %25 Aligned 8 - OpStore %9 %15 - OpBranch %4 - %4 = OpLabel - %18 = OpLoad %ulong %9 - %17 = OpIAdd %ulong %18 %ulong_1 - OpStore %10 %17 - OpBranch %6 - %35 = OpLabel - %20 = OpLoad %ulong %9 - %19 = OpIAdd %ulong %20 %ulong_2 - OpStore %10 %19 - OpBranch %6 - %6 = OpLabel - %21 = OpLoad %ulong %8 - %22 = OpLoad %ulong %10 - %26 = OpConvertUToPtr %_ptr_Generic_ulong %21 - OpStore %26 %22 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/brev.spvtxt b/ptx/src/test/spirv_run/brev.spvtxt deleted file mode 100644 index 7341adb..0000000 --- a/ptx/src/test/spirv_run/brev.spvtxt +++ /dev/null @@ -1,52 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "brev" - OpDecorate %20 LinkageAttributes "__zluda_ptx_impl__brev_b32" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %27 = OpTypeFunction %uint %uint - %ulong = OpTypeInt 64 0 - %29 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %20 = OpFunction %uint None %27 - %22 = OpFunctionParameter %uint - OpFunctionEnd - %1 = OpFunction %void None %29 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %13 = OpFunctionCall %uint %20 %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt deleted file mode 100644 index c29984e..0000000 --- a/ptx/src/test/spirv_run/call.spvtxt +++ /dev/null @@ -1,71 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - %37 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %4 "call" - OpExecutionMode %4 ContractionOff - OpDecorate %4 LinkageAttributes "call" Export - OpDecorate %1 LinkageAttributes "incr" Export - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %40 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %44 = OpTypeFunction %void %_ptr_Function_ulong %_ptr_Function_ulong - %ulong_1 = OpConstant %ulong 1 - %4 = OpFunction %void None %40 - %12 = OpFunctionParameter %ulong - %13 = OpFunctionParameter %ulong - %26 = OpLabel - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %5 %12 - OpStore %6 %13 - %14 = OpLoad %ulong %5 Aligned 8 - OpStore %7 %14 - %15 = OpLoad %ulong %6 Aligned 8 - OpStore %8 %15 - %17 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 - %16 = OpLoad %ulong %22 Aligned 8 - OpStore %9 %16 - %18 = OpLoad %ulong %9 - %23 = OpBitcast %_ptr_Function_ulong %10 - %24 = OpCopyObject %ulong %18 - OpStore %23 %24 Aligned 8 - %43 = OpFunctionCall %void %1 %10 %11 - %19 = OpLoad %ulong %11 Aligned 8 - OpStore %9 %19 - %20 = OpLoad %ulong %8 - %21 = OpLoad %ulong %9 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20 - OpStore %25 %21 Aligned 8 - OpReturn - OpFunctionEnd - %1 = OpFunction %void None %44 - %28 = OpFunctionParameter %_ptr_Function_ulong - %27 = OpFunctionParameter %_ptr_Function_ulong - %35 = OpLabel - %29 = OpVariable %_ptr_Function_ulong Function - %30 = OpLoad %ulong %28 Aligned 8 - OpStore %29 %30 - %32 = OpLoad %ulong %29 - %31 = OpIAdd %ulong %32 %ulong_1 - OpStore %29 %31 - %33 = OpLoad %ulong %29 - OpStore %27 %33 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt deleted file mode 100644 index 1feb5a0..0000000 --- a/ptx/src/test/spirv_run/clz.spvtxt +++ /dev/null @@ -1,52 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "clz" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %25 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %18 = OpExtInst %uint %22 clz %14 - %13 = OpCopyObject %uint %18 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %19 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %19 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/const.spvtxt b/ptx/src/test/spirv_run/const.spvtxt deleted file mode 100644 index 49ed9c3..0000000 --- a/ptx/src/test/spirv_run/const.spvtxt +++ /dev/null @@ -1,112 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %53 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "const" %1 - OpExecutionMode %2 ContractionOff - OpDecorate %1 Alignment 8 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %ushort = OpTypeInt 16 0 - %uint_4 = OpConstant %uint 4 -%_arr_ushort_uint_4 = OpTypeArray %ushort %uint_4 - %ushort_10 = OpConstant %ushort 10 - %ushort_20 = OpConstant %ushort 20 - %ushort_30 = OpConstant %ushort 30 - %ushort_40 = OpConstant %ushort 40 - %63 = OpConstantComposite %_arr_ushort_uint_4 %ushort_10 %ushort_20 %ushort_30 %ushort_40 - %uint_4_0 = OpConstant %uint 4 -%_ptr_UniformConstant__arr_ushort_uint_4 = OpTypePointer UniformConstant %_arr_ushort_uint_4 - %1 = OpVariable %_ptr_UniformConstant__arr_ushort_uint_4 UniformConstant %63 - %ulong = OpTypeInt 64 0 - %67 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_ushort = OpTypePointer Function %ushort -%_ptr_UniformConstant_ushort = OpTypePointer UniformConstant %ushort - %ulong_2 = OpConstant %ulong 2 - %uchar = OpTypeInt 8 0 -%_ptr_UniformConstant_uchar = OpTypePointer UniformConstant %uchar - %ulong_4 = OpConstant %ulong 4 - %ulong_6 = OpConstant %ulong 6 -%_ptr_Generic_ushort = OpTypePointer Generic %ushort - %ulong_2_0 = OpConstant %ulong 2 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_4_0 = OpConstant %ulong 4 - %ulong_6_0 = OpConstant %ulong 6 - %2 = OpFunction %void None %67 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %51 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ushort Function - %8 = OpVariable %_ptr_Function_ushort Function - %9 = OpVariable %_ptr_Function_ushort Function - %10 = OpVariable %_ptr_Function_ushort Function - OpStore %3 %11 - OpStore %4 %12 - %13 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %13 - %14 = OpLoad %ulong %4 Aligned 8 - OpStore %6 %14 - %39 = OpBitcast %_ptr_UniformConstant_ushort %1 - %15 = OpLoad %ushort %39 Aligned 2 - OpStore %7 %15 - %40 = OpBitcast %_ptr_UniformConstant_ushort %1 - %73 = OpBitcast %_ptr_UniformConstant_uchar %40 - %74 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %73 %ulong_2 - %28 = OpBitcast %_ptr_UniformConstant_ushort %74 - %16 = OpLoad %ushort %28 Aligned 2 - OpStore %8 %16 - %41 = OpBitcast %_ptr_UniformConstant_ushort %1 - %75 = OpBitcast %_ptr_UniformConstant_uchar %41 - %76 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %75 %ulong_4 - %30 = OpBitcast %_ptr_UniformConstant_ushort %76 - %17 = OpLoad %ushort %30 Aligned 2 - OpStore %9 %17 - %42 = OpBitcast %_ptr_UniformConstant_ushort %1 - %77 = OpBitcast %_ptr_UniformConstant_uchar %42 - %78 = OpInBoundsPtrAccessChain %_ptr_UniformConstant_uchar %77 %ulong_6 - %32 = OpBitcast %_ptr_UniformConstant_ushort %78 - %18 = OpLoad %ushort %32 Aligned 2 - OpStore %10 %18 - %19 = OpLoad %ulong %6 - %20 = OpLoad %ushort %7 - %43 = OpConvertUToPtr %_ptr_Generic_ushort %19 - %44 = OpCopyObject %ushort %20 - OpStore %43 %44 Aligned 2 - %21 = OpLoad %ulong %6 - %22 = OpLoad %ushort %8 - %45 = OpConvertUToPtr %_ptr_Generic_ushort %21 - %81 = OpBitcast %_ptr_Generic_uchar %45 - %82 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %81 %ulong_2_0 - %34 = OpBitcast %_ptr_Generic_ushort %82 - %46 = OpCopyObject %ushort %22 - OpStore %34 %46 Aligned 2 - %23 = OpLoad %ulong %6 - %24 = OpLoad %ushort %9 - %47 = OpConvertUToPtr %_ptr_Generic_ushort %23 - %83 = OpBitcast %_ptr_Generic_uchar %47 - %84 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %83 %ulong_4_0 - %36 = OpBitcast %_ptr_Generic_ushort %84 - %48 = OpCopyObject %ushort %24 - OpStore %36 %48 Aligned 2 - %25 = OpLoad %ulong %6 - %26 = OpLoad %ushort %10 - %49 = OpConvertUToPtr %_ptr_Generic_ushort %25 - %85 = OpBitcast %_ptr_Generic_uchar %49 - %86 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %85 %ulong_6_0 - %38 = OpBitcast %_ptr_Generic_ushort %86 - %50 = OpCopyObject %ushort %26 - OpStore %38 %50 Aligned 2 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt deleted file mode 100644 index b331ae6..0000000 --- a/ptx/src/test/spirv_run/constant_f32.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "constant_f32" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %float_0_5 = OpConstant %float 0.5 - %1 = OpFunction %void None %25 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %18 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpFMul %float %14 %float_0_5 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %19 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %19 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt deleted file mode 100644 index 9a5c7de..0000000 --- a/ptx/src/test/spirv_run/constant_negative.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "constant_negative" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint -%uint_4294967295 = OpConstant %uint 4294967295 - %1 = OpFunction %void None %25 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %18 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %13 = OpIMul %uint %14 %uint_4294967295 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %19 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %19 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cos.spvtxt b/ptx/src/test/spirv_run/cos.spvtxt deleted file mode 100644 index a79cdbe..0000000 --- a/ptx/src/test/spirv_run/cos.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cos" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 cos %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_f64_f32.spvtxt b/ptx/src/test/spirv_run/cvt_f64_f32.spvtxt deleted file mode 100644 index 907cce4..0000000 --- a/ptx/src/test/spirv_run/cvt_f64_f32.spvtxt +++ /dev/null @@ -1,55 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_f64_f32" - OpExecutionMode %1 DenormFlushToZero 16 - OpExecutionMode %1 DenormFlushToZero 32 - OpExecutionMode %1 DenormFlushToZero 64 - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %double = OpTypeFloat 64 -%_ptr_Function_double = OpTypePointer Function %double -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float -%_ptr_Generic_double = OpTypePointer Generic %double - %1 = OpFunction %void None %25 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_double Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13 - %12 = OpLoad %float %18 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %float %6 - %14 = OpFConvert %double %15 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %double %7 - %19 = OpConvertUToPtr %_ptr_Generic_double %16 - OpStore %19 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt deleted file mode 100644 index e10999c..0000000 --- a/ptx/src/test/spirv_run/cvt_rni.spvtxt +++ /dev/null @@ -1,69 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %34 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_rni" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %37 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %37 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %32 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %28 = OpConvertUToPtr %_ptr_Generic_float %13 - %12 = OpLoad %float %28 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_float %15 - %44 = OpBitcast %_ptr_Generic_uchar %29 - %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 - %25 = OpBitcast %_ptr_Generic_float %45 - %14 = OpLoad %float %25 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %float %6 - %16 = OpExtInst %float %34 rint %17 - OpStore %6 %16 - %19 = OpLoad %float %7 - %18 = OpExtInst %float %34 rint %19 - OpStore %7 %18 - %20 = OpLoad %ulong %5 - %21 = OpLoad %float %6 - %30 = OpConvertUToPtr %_ptr_Generic_float %20 - OpStore %30 %21 Aligned 4 - %22 = OpLoad %ulong %5 - %23 = OpLoad %float %7 - %31 = OpConvertUToPtr %_ptr_Generic_float %22 - %46 = OpBitcast %_ptr_Generic_uchar %31 - %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 - %27 = OpBitcast %_ptr_Generic_float %47 - OpStore %27 %23 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_rzi.spvtxt b/ptx/src/test/spirv_run/cvt_rzi.spvtxt deleted file mode 100644 index 7dda454..0000000 --- a/ptx/src/test/spirv_run/cvt_rzi.spvtxt +++ /dev/null @@ -1,69 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %34 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_rzi" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %37 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %37 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %32 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %28 = OpConvertUToPtr %_ptr_Generic_float %13 - %12 = OpLoad %float %28 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_float %15 - %44 = OpBitcast %_ptr_Generic_uchar %29 - %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 - %25 = OpBitcast %_ptr_Generic_float %45 - %14 = OpLoad %float %25 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %float %6 - %16 = OpExtInst %float %34 trunc %17 - OpStore %6 %16 - %19 = OpLoad %float %7 - %18 = OpExtInst %float %34 trunc %19 - OpStore %7 %18 - %20 = OpLoad %ulong %5 - %21 = OpLoad %float %6 - %30 = OpConvertUToPtr %_ptr_Generic_float %20 - OpStore %30 %21 Aligned 4 - %22 = OpLoad %ulong %5 - %23 = OpLoad %float %7 - %31 = OpConvertUToPtr %_ptr_Generic_float %22 - %46 = OpBitcast %_ptr_Generic_uchar %31 - %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 - %27 = OpBitcast %_ptr_Generic_float %47 - OpStore %27 %23 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt deleted file mode 100644 index 92322ec..0000000 --- a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_s16_s8" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %uchar = OpTypeInt 8 0 - %ushort = OpTypeInt 16 0 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %13 - %12 = OpLoad %uint %18 Aligned 4 - OpStore %7 %12 - %15 = OpLoad %uint %7 - %32 = OpBitcast %uint %15 - %34 = OpUConvert %uchar %32 - %20 = OpCopyObject %uchar %34 - %19 = OpSConvert %ushort %20 - %14 = OpSConvert %uint %19 - OpStore %6 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %6 - %21 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %21 %17 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt deleted file mode 100644 index c1229d4..0000000 --- a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt +++ /dev/null @@ -1,82 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %42 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_s32_f32" - OpDecorate %32 FPRoundingMode RTP - OpDecorate %34 FPRoundingMode RTP - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %45 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %float = OpTypeFloat 32 -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %ulong_4_0 = OpConstant %ulong 4 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %1 = OpFunction %void None %45 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %40 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_float %13 - %28 = OpLoad %float %29 Aligned 4 - %12 = OpBitcast %uint %28 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %30 = OpConvertUToPtr %_ptr_Generic_float %15 - %53 = OpBitcast %_ptr_Generic_uchar %30 - %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4 - %25 = OpBitcast %_ptr_Generic_float %54 - %31 = OpLoad %float %25 Aligned 4 - %14 = OpBitcast %uint %31 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %33 = OpBitcast %float %17 - %32 = OpConvertFToS %uint %33 - %16 = OpCopyObject %uint %32 - OpStore %6 %16 - %19 = OpLoad %uint %7 - %35 = OpBitcast %float %19 - %34 = OpConvertFToS %uint %35 - %18 = OpCopyObject %uint %34 - OpStore %7 %18 - %20 = OpLoad %ulong %5 - %21 = OpLoad %uint %6 - %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %20 - %37 = OpCopyObject %uint %21 - OpStore %36 %37 Aligned 4 - %22 = OpLoad %ulong %5 - %23 = OpLoad %uint %7 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %22 - %57 = OpBitcast %_ptr_CrossWorkgroup_uchar %38 - %58 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %57 %ulong_4_0 - %27 = OpBitcast %_ptr_CrossWorkgroup_uint %58 - %39 = OpCopyObject %uint %23 - OpStore %27 %39 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt deleted file mode 100644 index 1165290..0000000 --- a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt +++ /dev/null @@ -1,55 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_s64_s32" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_uint %13 - %18 = OpLoad %uint %19 Aligned 4 - %12 = OpCopyObject %uint %18 - OpStore %6 %12 - %15 = OpLoad %uint %6 - %14 = OpSConvert %ulong %15 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - %21 = OpCopyObject %ulong %17 - OpStore %20 %21 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt deleted file mode 100644 index 07b228e..0000000 --- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt +++ /dev/null @@ -1,56 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %25 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvt_sat_s_u" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %28 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %21 Aligned 4 - OpStore %6 %13 - %16 = OpLoad %uint %6 - %15 = OpSatConvertSToU %uint %16 - OpStore %7 %15 - %18 = OpLoad %uint %7 - %17 = OpCopyObject %uint %18 - OpStore %8 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %8 - %22 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %22 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt deleted file mode 100644 index e7a5655..0000000 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ /dev/null @@ -1,65 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %37 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvta" - %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %41 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %41 - %17 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %18 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %35 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %17 - OpStore %3 %18 - %10 = OpBitcast %_ptr_Function_ulong %2 - %9 = OpLoad %ulong %10 Aligned 8 - %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9 - OpStore %7 %19 - %12 = OpBitcast %_ptr_Function_ulong %3 - %11 = OpLoad %ulong %12 Aligned 8 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11 - OpStore %8 %20 - %21 = OpLoad %_ptr_CrossWorkgroup_uchar %7 - %14 = OpConvertPtrToU %ulong %21 - %30 = OpCopyObject %ulong %14 - %29 = OpCopyObject %ulong %30 - %13 = OpCopyObject %ulong %29 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 - OpStore %7 %22 - %23 = OpLoad %_ptr_CrossWorkgroup_uchar %8 - %16 = OpConvertPtrToU %ulong %23 - %32 = OpCopyObject %ulong %16 - %31 = OpCopyObject %ulong %32 - %15 = OpCopyObject %ulong %31 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 - OpStore %8 %24 - %26 = OpLoad %_ptr_CrossWorkgroup_uchar %7 - %33 = OpBitcast %_ptr_CrossWorkgroup_float %26 - %25 = OpLoad %float %33 Aligned 4 - OpStore %6 %25 - %27 = OpLoad %_ptr_CrossWorkgroup_uchar %8 - %28 = OpLoad %float %6 - %34 = OpBitcast %_ptr_CrossWorkgroup_float %27 - OpStore %34 %28 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt deleted file mode 100644 index 858ec8d..0000000 --- a/ptx/src/test/spirv_run/div_approx.spvtxt +++ /dev/null @@ -1,60 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "div_approx" - OpDecorate %16 FPFastMathMode AllowRecip - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_float %13 - %12 = OpLoad %float %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_float %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_float %39 - %14 = OpLoad %float %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %float %6 - %18 = OpLoad %float %7 - %16 = OpFDiv %float %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %float %6 - %25 = OpConvertUToPtr %_ptr_Generic_float %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ex2.spvtxt b/ptx/src/test/spirv_run/ex2.spvtxt deleted file mode 100644 index 29e5e86..0000000 --- a/ptx/src/test/spirv_run/ex2.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ex2" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 exp2 %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_func.spvtxt b/ptx/src/test/spirv_run/extern_func.spvtxt deleted file mode 100644 index b757029..0000000 --- a/ptx/src/test/spirv_run/extern_func.spvtxt +++ /dev/null @@ -1,75 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - %31 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %4 "extern_func" - OpExecutionMode %4 ContractionOff - OpDecorate %1 LinkageAttributes "foobar" Import - OpDecorate %12 Alignment 16 - OpDecorate %4 LinkageAttributes "extern_func" Export - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_16 = OpConstant %uint 16 -%_arr_uchar_uint_16 = OpTypeArray %uchar %uint_16 -%_ptr_Function__arr_uchar_uint_16 = OpTypePointer Function %_arr_uchar_uint_16 - %40 = OpTypeFunction %void %_ptr_Function_ulong %_ptr_Function__arr_uchar_uint_16 - %uint_16_0 = OpConstant %uint 16 - %42 = OpTypeFunction %void %ulong %ulong - %uint_16_1 = OpConstant %uint 16 -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_0 = OpConstant %ulong 0 -%_ptr_Function_uchar = OpTypePointer Function %uchar -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %40 - %3 = OpFunctionParameter %_ptr_Function_ulong - %2 = OpFunctionParameter %_ptr_Function__arr_uchar_uint_16 - OpFunctionEnd - %4 = OpFunction %void None %42 - %13 = OpFunctionParameter %ulong - %14 = OpFunctionParameter %ulong - %29 = OpLabel - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function__arr_uchar_uint_16 Function - OpStore %5 %13 - OpStore %6 %14 - %15 = OpLoad %ulong %5 Aligned 8 - OpStore %7 %15 - %16 = OpLoad %ulong %6 Aligned 8 - OpStore %8 %16 - %18 = OpLoad %ulong %7 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %18 - %17 = OpLoad %ulong %25 Aligned 8 - OpStore %9 %17 - %19 = OpLoad %ulong %9 - %46 = OpBitcast %_ptr_Function_uchar %11 - %47 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %46 %ulong_0 - %24 = OpBitcast %_ptr_Function_ulong %47 - %26 = OpCopyObject %ulong %19 - OpStore %24 %26 Aligned 8 - %48 = OpFunctionCall %void %1 %11 %12 - %27 = OpBitcast %_ptr_Function_ulong %12 - %20 = OpLoad %ulong %27 Aligned 8 - OpStore %10 %20 - %21 = OpLoad %ulong %8 - %22 = OpLoad %ulong %10 - %28 = OpConvertUToPtr %_ptr_Generic_ulong %21 - OpStore %28 %22 Aligned 8 - OpReturn - OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt deleted file mode 100644 index 025cd81..0000000 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ /dev/null @@ -1,56 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "extern_shared" %1 - OpExecutionMode %2 ContractionOff - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup - %ulong = OpTypeInt 64 0 - %29 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %29 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %3 %8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %11 = OpLoad %ulong %4 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %12 = OpLoad %ulong %18 Aligned 8 - OpStore %7 %12 - %14 = OpLoad %ulong %7 - %19 = OpBitcast %_ptr_Workgroup_ulong %1 - OpStore %19 %14 Aligned 8 - %20 = OpBitcast %_ptr_Workgroup_ulong %1 - %15 = OpLoad %ulong %20 Aligned 8 - OpStore %7 %15 - %16 = OpLoad %ulong %6 - %17 = OpLoad %ulong %7 - %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - OpStore %21 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt deleted file mode 100644 index bf1dccd..0000000 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ /dev/null @@ -1,75 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %35 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %12 "extern_shared_call" %1 - OpExecutionMode %12 ContractionOff - OpDecorate %1 Alignment 4 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup - %39 = OpTypeFunction %void %_ptr_Workgroup_uint - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %ulong_2 = OpConstant %ulong 2 - %43 = OpTypeFunction %void %ulong %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %39 - %34 = OpFunctionParameter %_ptr_Workgroup_uint - %11 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %9 = OpBitcast %_ptr_Workgroup_ulong %34 - %4 = OpLoad %ulong %9 Aligned 8 - OpStore %3 %4 - %6 = OpLoad %ulong %3 - %5 = OpIAdd %ulong %6 %ulong_2 - OpStore %3 %5 - %7 = OpLoad %ulong %3 - %10 = OpBitcast %_ptr_Workgroup_ulong %34 - OpStore %10 %7 Aligned 8 - OpReturn - OpFunctionEnd - %12 = OpFunction %void None %43 - %18 = OpFunctionParameter %ulong - %19 = OpFunctionParameter %ulong - %32 = OpLabel - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - %16 = OpVariable %_ptr_Function_ulong Function - %17 = OpVariable %_ptr_Function_ulong Function - OpStore %13 %18 - OpStore %14 %19 - %20 = OpLoad %ulong %13 Aligned 8 - OpStore %15 %20 - %21 = OpLoad %ulong %14 Aligned 8 - OpStore %16 %21 - %23 = OpLoad %ulong %15 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 - %22 = OpLoad %ulong %28 Aligned 8 - OpStore %17 %22 - %24 = OpLoad %ulong %17 - %29 = OpBitcast %_ptr_Workgroup_ulong %1 - OpStore %29 %24 Aligned 8 - %45 = OpFunctionCall %void %2 %1 - %30 = OpBitcast %_ptr_Workgroup_ulong %1 - %25 = OpLoad %ulong %30 Aligned 8 - OpStore %17 %25 - %26 = OpLoad %ulong %16 - %27 = OpLoad %ulong %17 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 - OpStore %31 %27 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt deleted file mode 100644 index 91a2159..0000000 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ /dev/null @@ -1,69 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %35 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "fma" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %38 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %1 = OpFunction %void None %38 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %33 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_float %14 - %13 = OpLoad %float %29 Aligned 4 - OpStore %6 %13 - %16 = OpLoad %ulong %4 - %30 = OpConvertUToPtr %_ptr_Generic_float %16 - %45 = OpBitcast %_ptr_Generic_uchar %30 - %46 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %45 %ulong_4 - %26 = OpBitcast %_ptr_Generic_float %46 - %15 = OpLoad %float %26 Aligned 4 - OpStore %7 %15 - %18 = OpLoad %ulong %4 - %31 = OpConvertUToPtr %_ptr_Generic_float %18 - %47 = OpBitcast %_ptr_Generic_uchar %31 - %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_8 - %28 = OpBitcast %_ptr_Generic_float %48 - %17 = OpLoad %float %28 Aligned 4 - OpStore %8 %17 - %20 = OpLoad %float %6 - %21 = OpLoad %float %7 - %22 = OpLoad %float %8 - %19 = OpExtInst %float %35 fma %20 %21 %22 - OpStore %6 %19 - %23 = OpLoad %ulong %5 - %24 = OpLoad %float %6 - %32 = OpConvertUToPtr %_ptr_Generic_float %23 - OpStore %32 %24 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/func_ptr.spvtxt b/ptx/src/test/spirv_run/func_ptr.spvtxt deleted file mode 100644 index 4ff74c6..0000000 --- a/ptx/src/test/spirv_run/func_ptr.spvtxt +++ /dev/null @@ -1,77 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - %39 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %12 "func_ptr" - OpExecutionMode %12 ContractionOff - OpDecorate %12 LinkageAttributes "func_ptr" Export - %void = OpTypeVoid - %float = OpTypeFloat 32 - %42 = OpTypeFunction %float %float %float -%_ptr_Function_float = OpTypePointer Function %float - %ulong = OpTypeInt 64 0 - %45 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %ulong_0 = OpConstant %ulong 0 - %1 = OpFunction %float None %42 - %5 = OpFunctionParameter %float - %6 = OpFunctionParameter %float - %11 = OpLabel - %3 = OpVariable %_ptr_Function_float Function - %4 = OpVariable %_ptr_Function_float Function - %2 = OpVariable %_ptr_Function_float Function - OpStore %3 %5 - OpStore %4 %6 - %8 = OpLoad %float %3 - %9 = OpLoad %float %4 - %7 = OpFAdd %float %8 %9 - OpStore %2 %7 - %10 = OpLoad %float %2 - OpReturnValue %10 - OpFunctionEnd - %12 = OpFunction %void None %45 - %20 = OpFunctionParameter %ulong - %21 = OpFunctionParameter %ulong - %37 = OpLabel - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - %16 = OpVariable %_ptr_Function_ulong Function - %17 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_ulong Function - OpStore %13 %20 - OpStore %14 %21 - %22 = OpLoad %ulong %13 Aligned 8 - OpStore %15 %22 - %23 = OpLoad %ulong %14 Aligned 8 - OpStore %16 %23 - %25 = OpLoad %ulong %15 - %35 = OpConvertUToPtr %_ptr_Generic_ulong %25 - %24 = OpLoad %ulong %35 Aligned 8 - OpStore %17 %24 - %27 = OpLoad %ulong %17 - %26 = OpIAdd %ulong %27 %ulong_1 - OpStore %18 %26 - %28 = OpCopyObject %ulong %ulong_0 - OpStore %19 %28 - %30 = OpLoad %ulong %18 - %31 = OpLoad %ulong %19 - %29 = OpIAdd %ulong %30 %31 - OpStore %18 %29 - %32 = OpLoad %ulong %16 - %33 = OpLoad %ulong %18 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %32 - OpStore %36 %33 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/global_array.spvtxt b/ptx/src/test/spirv_run/global_array.spvtxt deleted file mode 100644 index 4eccb2f..0000000 --- a/ptx/src/test/spirv_run/global_array.spvtxt +++ /dev/null @@ -1,53 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "global_array" %1 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uint_4 = OpConstant %uint 4 -%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 - %uint_1 = OpConstant %uint 1 - %uint_0 = OpConstant %uint 0 - %28 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0 - %uint_4_0 = OpConstant %uint 4 -%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4 - %1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %28 - %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %2 = OpFunction %void None %32 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %19 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %3 %8 - OpStore %4 %9 - %16 = OpConvertPtrToU %ulong %1 - %10 = OpCopyObject %ulong %16 - OpStore %5 %10 - %11 = OpLoad %ulong %4 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %13 - %12 = OpLoad %uint %17 Aligned 4 - OpStore %7 %12 - %14 = OpLoad %ulong %6 - %15 = OpLoad %uint %7 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14 - OpStore %18 %15 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/implicit_param.spvtxt b/ptx/src/test/spirv_run/implicit_param.spvtxt deleted file mode 100644 index 760761a..0000000 --- a/ptx/src/test/spirv_run/implicit_param.spvtxt +++ /dev/null @@ -1,53 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "implicit_param" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13 - %12 = OpLoad %float %18 Aligned 4 - OpStore %6 %12 - %14 = OpLoad %float %6 - %19 = OpBitcast %_ptr_Function_float %7 - OpStore %19 %14 Aligned 4 - %20 = OpBitcast %_ptr_Function_float %7 - %15 = OpLoad %float %20 Aligned 4 - OpStore %6 %15 - %16 = OpLoad %ulong %5 - %17 = OpLoad %float %6 - %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 - OpStore %21 %17 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/lanemask_lt.spvtxt b/ptx/src/test/spirv_run/lanemask_lt.spvtxt deleted file mode 100644 index 3de53ce..0000000 --- a/ptx/src/test/spirv_run/lanemask_lt.spvtxt +++ /dev/null @@ -1,70 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %40 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "lanemask_lt" - OpExecutionMode %1 ContractionOff - OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_lanemask_lt" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %43 = OpTypeFunction %uint - %ulong = OpTypeInt 64 0 - %45 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %uint_1 = OpConstant %uint 1 - %11 = OpFunction %uint None %43 - OpFunctionEnd - %1 = OpFunction %void None %45 - %13 = OpFunctionParameter %ulong - %14 = OpFunctionParameter %ulong - %38 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %13 - OpStore %3 %14 - %15 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %15 - %16 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %16 - %18 = OpLoad %ulong %4 - %29 = OpConvertUToPtr %_ptr_Generic_uint %18 - %28 = OpLoad %uint %29 Aligned 4 - %17 = OpCopyObject %uint %28 - OpStore %6 %17 - %20 = OpLoad %uint %6 - %31 = OpCopyObject %uint %20 - %30 = OpIAdd %uint %31 %uint_1 - %19 = OpCopyObject %uint %30 - OpStore %7 %19 - %10 = OpFunctionCall %uint %11 - %32 = OpCopyObject %uint %10 - %21 = OpCopyObject %uint %32 - OpStore %8 %21 - %23 = OpLoad %uint %7 - %24 = OpLoad %uint %8 - %34 = OpCopyObject %uint %23 - %35 = OpCopyObject %uint %24 - %33 = OpIAdd %uint %34 %35 - %22 = OpCopyObject %uint %33 - OpStore %7 %22 - %25 = OpLoad %ulong %5 - %26 = OpLoad %uint %7 - %36 = OpConvertUToPtr %_ptr_Generic_uint %25 - %37 = OpCopyObject %uint %26 - OpStore %36 %37 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st.spvtxt b/ptx/src/test/spirv_run/ld_st.spvtxt deleted file mode 100644 index 447b1aa..0000000 --- a/ptx/src/test/spirv_run/ld_st.spvtxt +++ /dev/null @@ -1,42 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %19 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %22 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %22 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %17 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %15 = OpConvertUToPtr %_ptr_Generic_ulong %12 - %11 = OpLoad %ulong %15 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %14 = OpLoad %ulong %6 - %16 = OpConvertUToPtr %_ptr_Generic_ulong %13 - OpStore %16 %14 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt deleted file mode 100644 index 9c0e508..0000000 --- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt +++ /dev/null @@ -1,56 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st_implicit" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%ulong_81985529216486895 = OpConstant %ulong 81985529216486895 - %float = OpTypeFloat 32 -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %uint = OpTypeInt 32 0 - %1 = OpFunction %void None %26 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %11 = OpCopyObject %ulong %ulong_81985529216486895 - OpStore %6 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13 - %17 = OpLoad %float %18 Aligned 4 - %31 = OpBitcast %uint %17 - %12 = OpUConvert %ulong %31 - OpStore %6 %12 - %14 = OpLoad %ulong %5 - %15 = OpLoad %ulong %6 - %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 - %32 = OpBitcast %ulong %15 - %33 = OpUConvert %uint %32 - %20 = OpBitcast %float %33 - OpStore %19 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt deleted file mode 100644 index ea97222..0000000 --- a/ptx/src/test/spirv_run/ld_st_offset.spvtxt +++ /dev/null @@ -1,63 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st_offset" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %33 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %24 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %15 - %40 = OpBitcast %_ptr_Generic_uchar %25 - %41 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %40 %ulong_4 - %21 = OpBitcast %_ptr_Generic_uint %41 - %14 = OpLoad %uint %21 Aligned 4 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %7 - %26 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %26 %17 Aligned 4 - %18 = OpLoad %ulong %5 - %19 = OpLoad %uint %6 - %27 = OpConvertUToPtr %_ptr_Generic_uint %18 - %42 = OpBitcast %_ptr_Generic_uchar %27 - %43 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %42 %ulong_4_0 - %23 = OpBitcast %_ptr_Generic_uint %43 - OpStore %23 %19 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/lg2.spvtxt b/ptx/src/test/spirv_run/lg2.spvtxt deleted file mode 100644 index a8175cf..0000000 --- a/ptx/src/test/spirv_run/lg2.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "lg2" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 log2 %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt deleted file mode 100644 index a2cfd4c..0000000 --- a/ptx/src/test/spirv_run/local_align.spvtxt +++ /dev/null @@ -1,49 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %20 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "local_align" - OpDecorate %4 Alignment 8 - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %23 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_8 = OpConstant %uint 8 -%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8 -%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %23 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %18 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %16 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %16 Aligned 8 - OpStore %7 %12 - %14 = OpLoad %ulong %6 - %15 = OpLoad %ulong %7 - %17 = OpConvertUToPtr %_ptr_Generic_ulong %14 - OpStore %17 %15 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt deleted file mode 100644 index 0ee3ca7..0000000 --- a/ptx/src/test/spirv_run/mad_s32.spvtxt +++ /dev/null @@ -1,87 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %46 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mad_s32" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %49 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %ulong_4_0 = OpConstant %ulong 4 - %ulong_8_0 = OpConstant %ulong 8 - %1 = OpFunction %void None %49 - %10 = OpFunctionParameter %ulong - %11 = OpFunctionParameter %ulong - %44 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_uint Function - %9 = OpVariable %_ptr_Function_uint Function - OpStore %2 %10 - OpStore %3 %11 - %12 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %12 - %13 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %13 - %15 = OpLoad %ulong %4 - %38 = OpConvertUToPtr %_ptr_Generic_uint %15 - %14 = OpLoad %uint %38 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %ulong %4 - %39 = OpConvertUToPtr %_ptr_Generic_uint %17 - %56 = OpBitcast %_ptr_Generic_uchar %39 - %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4 - %31 = OpBitcast %_ptr_Generic_uint %57 - %16 = OpLoad %uint %31 Aligned 4 - OpStore %8 %16 - %19 = OpLoad %ulong %4 - %40 = OpConvertUToPtr %_ptr_Generic_uint %19 - %58 = OpBitcast %_ptr_Generic_uchar %40 - %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_8 - %33 = OpBitcast %_ptr_Generic_uint %59 - %18 = OpLoad %uint %33 Aligned 4 - OpStore %9 %18 - %21 = OpLoad %uint %7 - %22 = OpLoad %uint %8 - %23 = OpLoad %uint %9 - %60 = OpIMul %uint %21 %22 - %20 = OpIAdd %uint %23 %60 - OpStore %6 %20 - %24 = OpLoad %ulong %5 - %25 = OpLoad %uint %6 - %41 = OpConvertUToPtr %_ptr_Generic_uint %24 - OpStore %41 %25 Aligned 4 - %26 = OpLoad %ulong %5 - %27 = OpLoad %uint %6 - %42 = OpConvertUToPtr %_ptr_Generic_uint %26 - %61 = OpBitcast %_ptr_Generic_uchar %42 - %62 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %61 %ulong_4_0 - %35 = OpBitcast %_ptr_Generic_uint %62 - OpStore %35 %27 Aligned 4 - %28 = OpLoad %ulong %5 - %29 = OpLoad %uint %6 - %43 = OpConvertUToPtr %_ptr_Generic_uint %28 - %63 = OpBitcast %_ptr_Generic_uchar %43 - %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_8_0 - %37 = OpBitcast %_ptr_Generic_uint %64 - OpStore %37 %29 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt deleted file mode 100644 index 86b732a..0000000 --- a/ptx/src/test/spirv_run/max.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "max" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %39 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %16 = OpExtInst %uint %28 s_max %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/membar.spvtxt b/ptx/src/test/spirv_run/membar.spvtxt deleted file mode 100644 index d808cf3..0000000 --- a/ptx/src/test/spirv_run/membar.spvtxt +++ /dev/null @@ -1,49 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %20 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "membar" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %23 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %uint_0 = OpConstant %uint 0 - %uint_784 = OpConstant %uint 784 - %1 = OpFunction %void None %23 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %18 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %16 = OpConvertUToPtr %_ptr_Generic_uint %12 - %15 = OpLoad %uint %16 Aligned 4 - %11 = OpCopyObject %uint %15 - OpStore %6 %11 - OpMemoryBarrier %uint_0 %uint_784 - %13 = OpLoad %ulong %5 - %14 = OpLoad %uint %6 - %17 = OpConvertUToPtr %_ptr_Generic_uint %13 - OpStore %17 %14 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt deleted file mode 100644 index a187376..0000000 --- a/ptx/src/test/spirv_run/min.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "min" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %39 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %16 = OpExtInst %uint %28 s_min %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e15d6ea..f4b7921 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,31 +1,11 @@ use crate::pass; -use crate::ptx; -use crate::translate; use hip_runtime_sys::hipError_t; -use rspirv::{ - binary::{Assemble, Disassemble}, - dr::{Block, Function, Instruction, Loader, Operand}, -}; -use spirv_headers::Word; -use spirv_tools_sys::{ - spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env, -}; -use std::collections::hash_map::Entry; use std::error; -use std::ffi::{c_void, CStr, CString}; +use std::ffi::{CStr, CString}; use std::fmt; use std::fmt::{Debug, Display, Formatter}; -use std::fs::File; -use std::hash::Hash; -use std::io; -use std::io::Read; -use std::io::Write; use std::mem; -use std::path::Path; -use std::process::Command; -use std::slice; -use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str}; -use tempfile::NamedTempFile; +use std::{ptr, str}; macro_rules! test_ptx { ($fn_name:ident, $input:expr, $output:expr) => { @@ -65,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); test_ptx!(shl, [11u64], [44u64]); -test_ptx!(shl_link_hack, [11u64], [44u64]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); @@ -236,7 +215,7 @@ fn test_hip_assert< output: &mut [Output], ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module2(ast).unwrap(); + let llvm_ir = pass::to_llvm_module(ast).unwrap(); let name = CString::new(name)?; let result = run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?; @@ -326,6 +305,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, &*module.llvm_ir, + module.linked_bitcode(), ) .unwrap(); let mut module = ptr::null_mut(); @@ -381,226 +361,3 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def } Ok(result) } - -struct EqMap -where - T: Eq + Copy + Hash, -{ - m1: HashMap, - m2: HashMap, -} - -impl EqMap { - fn new() -> Self { - EqMap { - m1: HashMap::new(), - m2: HashMap::new(), - } - } - - fn is_equal(&mut self, t1: T, t2: T) -> bool { - match (self.m1.entry(t1), self.m2.entry(t2)) { - (Entry::Occupied(entry1), Entry::Occupied(entry2)) => { - *entry1.get() == t2 && *entry2.get() == t1 - } - (Entry::Vacant(entry1), Entry::Vacant(entry2)) => { - entry1.insert(t2); - entry2.insert(t1); - true - } - _ => false, - } - } -} - -fn is_spirv_fns_equal(fns1: &[Function], fns2: &[Function]) -> bool { - if fns1.len() != fns2.len() { - return false; - } - for (fn1, fn2) in fns1.iter().zip(fns2.iter()) { - if !is_spirv_fn_equal(fn1, fn2) { - return false; - } - } - true -} - -fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { - let mut map = EqMap::new(); - if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) { - return false; - } - if !is_option_equal(&fn1.end, &fn2.end, &mut map, is_instr_equal) { - return false; - } - if fn1.parameters.len() != fn2.parameters.len() { - return false; - } - for (inst1, inst2) in fn1.parameters.iter().zip(fn2.parameters.iter()) { - if !is_instr_equal(inst1, inst2, &mut map) { - return false; - } - } - if fn1.blocks.len() != fn2.blocks.len() { - return false; - } - for (b1, b2) in fn1.blocks.iter().zip(fn2.blocks.iter()) { - if !is_block_equal(b1, b2, &mut map) { - return false; - } - } - true -} - -fn is_block_equal(b1: &Block, b2: &Block, map: &mut EqMap) -> bool { - if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) { - return false; - } - if b1.instructions.len() != b2.instructions.len() { - return false; - } - for (inst1, inst2) in b1.instructions.iter().zip(b2.instructions.iter()) { - if !is_instr_equal(inst1, inst2, map) { - return false; - } - } - true -} - -fn is_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut EqMap) -> bool { - if instr1.class.opcode != instr2.class.opcode { - return false; - } - if !is_option_equal(&instr1.result_type, &instr2.result_type, map, is_word_equal) { - return false; - } - if !is_option_equal(&instr1.result_id, &instr2.result_id, map, is_word_equal) { - return false; - } - if instr1.operands.len() != instr2.operands.len() { - return false; - } - for (o1, o2) in instr1.operands.iter().zip(instr2.operands.iter()) { - match (o1, o2) { - (Operand::IdMemorySemantics(w1), Operand::IdMemorySemantics(w2)) => { - if !is_word_equal(w1, w2, map) { - return false; - } - } - (Operand::IdScope(w1), Operand::IdScope(w2)) => { - if !is_word_equal(w1, w2, map) { - return false; - } - } - (Operand::IdRef(w1), Operand::IdRef(w2)) => { - if !is_word_equal(w1, w2, map) { - return false; - } - } - (o1, o2) => { - if o1 != o2 { - return false; - } - } - } - } - true -} - -fn is_word_equal(t1: &Word, t2: &Word, map: &mut EqMap) -> bool { - map.is_equal(*t1, *t2) -} - -fn is_option_equal) -> bool>( - o1: &Option, - o2: &Option, - map: &mut EqMap, - f: F, -) -> bool { - match (o1, o2) { - (Some(t1), Some(t2)) => f(t1, t2, map), - (None, None) => true, - _ => panic!(), - } -} - -unsafe extern "C" fn parse_header_cb( - user_data: *mut c_void, - endian: spv_endianness_t, - magic: u32, - version: u32, - generator: u32, - id_bound: u32, - reserved: u32, -) -> spv_result_t { - if endian == spv_endianness_t::SPV_ENDIANNESS_BIG { - return spv_result_t::SPV_UNSUPPORTED; - } - let result_vec: &mut Vec = std::mem::transmute(user_data); - result_vec.push(magic); - result_vec.push(version); - result_vec.push(generator); - result_vec.push(id_bound); - result_vec.push(reserved); - spv_result_t::SPV_SUCCESS -} - -unsafe extern "C" fn parse_instruction_cb( - user_data: *mut c_void, - inst: *const spv_parsed_instruction_t, -) -> spv_result_t { - let inst = &*inst; - let result_vec: &mut Vec = std::mem::transmute(user_data); - for i in 0..inst.num_words { - result_vec.push(*(inst.words.add(i as usize))); - } - spv_result_t::SPV_SUCCESS -} - -const LLVM_SPIRV: &'static str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv"; -const AMDGPU: &'static str = "/opt/rocm/"; -const AMDGPU_TARGET: &'static str = "amdgcn-amd-amdhsa"; -const AMDGPU_BITCODE: [&'static str; 8] = [ - "opencl.bc", - "ocml.bc", - "ockl.bc", - "oclc_correctly_rounded_sqrt_off.bc", - "oclc_daz_opt_on.bc", - "oclc_finite_only_off.bc", - "oclc_unsafe_math_off.bc", - "oclc_wavefrontsize64_off.bc", -]; -const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_"; - -fn persist_file(path: &Path) -> io::Result<()> { - let mut persistent = PathBuf::from("/tmp/zluda"); - std::fs::create_dir_all(&persistent)?; - persistent.push(path.file_name().unwrap()); - std::fs::copy(path, persistent)?; - Ok(()) -} - -fn get_bitcode_paths(device_name: &str) -> impl Iterator { - let generic_paths = AMDGPU_BITCODE.iter().map(|x| { - let mut path = PathBuf::from(AMDGPU); - path.push("amdgcn"); - path.push("bitcode"); - path.push(x); - path - }); - let suffix = if let Some(suffix_idx) = device_name.find(':') { - suffix_idx - } else { - device_name.len() - }; - let mut additional_path = PathBuf::from(AMDGPU); - additional_path.push("amdgcn"); - additional_path.push("bitcode"); - additional_path.push(format!( - "{}{}{}", - AMDGPU_BITCODE_DEVICE_PREFIX, - &device_name[3..suffix], - ".bc" - )); - generic_paths.chain(std::iter::once(additional_path)) -} diff --git a/ptx/src/test/spirv_run/mov.spvtxt b/ptx/src/test/spirv_run/mov.spvtxt deleted file mode 100644 index 13473d9..0000000 --- a/ptx/src/test/spirv_run/mov.spvtxt +++ /dev/null @@ -1,46 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mov" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %25 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %18 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpCopyObject %ulong %15 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %19 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mov_address.spvtxt b/ptx/src/test/spirv_run/mov_address.spvtxt deleted file mode 100644 index 26ae21f..0000000 --- a/ptx/src/test/spirv_run/mov_address.spvtxt +++ /dev/null @@ -1,33 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %12 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mov_address" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %15 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uchar = OpTypeInt 8 0 - %uint = OpTypeInt 32 0 - %uint_8 = OpConstant %uint 8 -%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8 -%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 - %1 = OpFunction %void None %15 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %10 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function - %5 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %6 - OpStore %3 %7 - %9 = OpConvertPtrToU %ulong %4 - %8 = OpCopyObject %ulong %9 - OpStore %5 %8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt deleted file mode 100644 index e7a4a56..0000000 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mul_ftz" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_float %13 - %12 = OpLoad %float %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_float %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_float %39 - %14 = OpLoad %float %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %float %6 - %18 = OpLoad %float %7 - %16 = OpFMul %float %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %float %6 - %25 = OpConvertUToPtr %_ptr_Generic_float %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_hi.spvtxt b/ptx/src/test/spirv_run/mul_hi.spvtxt deleted file mode 100644 index 93537b3..0000000 --- a/ptx/src/test/spirv_run/mul_hi.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mul_hi" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_2 = OpConstant %ulong 2 - %1 = OpFunction %void None %26 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpExtInst %ulong %23 u_mul_hi %15 %ulong_2 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_lo.spvtxt b/ptx/src/test/spirv_run/mul_lo.spvtxt deleted file mode 100644 index 7d69cfb..0000000 --- a/ptx/src/test/spirv_run/mul_lo.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mul_lo" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_2 = OpConstant %ulong 2 - %1 = OpFunction %void None %26 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpIMul %ulong %15 %ulong_2 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt deleted file mode 100644 index 5326baa..0000000 --- a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mul_non_ftz" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_float %13 - %12 = OpLoad %float %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_float %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_float %39 - %14 = OpLoad %float %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %float %6 - %18 = OpLoad %float %7 - %16 = OpFMul %float %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %float %6 - %25 = OpConvertUToPtr %_ptr_Generic_float %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt deleted file mode 100644 index b8ffac0..0000000 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ /dev/null @@ -1,66 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "mul_wide" - OpExecutionMode %1 ContractionOff - OpDecorate %17 NoSignedWrap - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %33 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14 - %13 = OpLoad %uint %24 Aligned 4 - OpStore %6 %13 - %16 = OpLoad %ulong %4 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %25 - %41 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %40 %ulong_4 - %23 = OpBitcast %_ptr_CrossWorkgroup_uint %41 - %15 = OpLoad %uint %23 Aligned 4 - OpStore %7 %15 - %18 = OpLoad %uint %6 - %19 = OpLoad %uint %7 - %42 = OpSConvert %ulong %18 - %43 = OpSConvert %ulong %19 - %17 = OpIMul %ulong %42 %43 - OpStore %8 %17 - %20 = OpLoad %ulong %5 - %21 = OpLoad %ulong %8 - %26 = OpConvertUToPtr %_ptr_Generic_ulong %20 - %27 = OpCopyObject %ulong %21 - OpStore %26 %27 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/neg.spvtxt b/ptx/src/test/spirv_run/neg.spvtxt deleted file mode 100644 index d5ab925..0000000 --- a/ptx/src/test/spirv_run/neg.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "neg" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %13 = OpSNegate %uint %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt deleted file mode 100644 index 92dc7cc..0000000 --- a/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt +++ /dev/null @@ -1,60 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %27 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "non_scalar_ptr_offset" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %30 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %ulong_8 = OpConstant %ulong 8 - %v2uint = OpTypeVector %uint 2 -%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %1 = OpFunction %void None %30 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %25 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13 - %38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23 - %39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8 - %22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39 - %8 = OpLoad %v2uint %22 Aligned 8 - %14 = OpCompositeExtract %uint %8 0 - %15 = OpCompositeExtract %uint %8 1 - OpStore %6 %14 - OpStore %7 %15 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %16 = OpIAdd %uint %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19 - OpStore %24 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt deleted file mode 100644 index 655a892..0000000 --- a/ptx/src/test/spirv_run/not.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "not" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %18 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %20 = OpCopyObject %ulong %15 - %19 = OpNot %ulong %20 - %14 = OpCopyObject %ulong %19 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %21 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt deleted file mode 100644 index 6754ce4..0000000 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ /dev/null @@ -1,60 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ntid" - OpExecutionMode %1 ContractionOff - OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_ntid" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %34 = OpTypeFunction %uint %uchar - %ulong = OpTypeInt 64 0 - %36 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %uchar_0 = OpConstant %uchar 0 - %11 = OpFunction %uint None %34 - %13 = OpFunctionParameter %uchar - OpFunctionEnd - %1 = OpFunction %void None %36 - %14 = OpFunctionParameter %ulong - %15 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %14 - OpStore %3 %15 - %16 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %16 - %17 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %17 - %19 = OpLoad %ulong %4 - %26 = OpConvertUToPtr %_ptr_Generic_uint %19 - %18 = OpLoad %uint %26 Aligned 4 - OpStore %6 %18 - %10 = OpFunctionCall %uint %11 %uchar_0 - %20 = OpCopyObject %uint %10 - OpStore %7 %20 - %22 = OpLoad %uint %6 - %23 = OpLoad %uint %7 - %21 = OpIAdd %uint %22 %23 - OpStore %6 %21 - %24 = OpLoad %ulong %5 - %25 = OpLoad %uint %6 - %27 = OpConvertUToPtr %_ptr_Generic_uint %24 - OpStore %27 %25 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt deleted file mode 100644 index 82db00c..0000000 --- a/ptx/src/test/spirv_run/or.spvtxt +++ /dev/null @@ -1,60 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %31 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "or" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %34 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_8 = OpConstant %ulong 8 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %34 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %29 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %23 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %39 = OpBitcast %_ptr_Generic_uchar %24 - %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_8 - %22 = OpBitcast %_ptr_Generic_ulong %40 - %14 = OpLoad %ulong %22 Aligned 8 - OpStore %7 %14 - %17 = OpLoad %ulong %6 - %18 = OpLoad %ulong %7 - %26 = OpCopyObject %ulong %17 - %27 = OpCopyObject %ulong %18 - %25 = OpBitwiseOr %ulong %26 %27 - %16 = OpCopyObject %ulong %25 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %ulong %6 - %28 = OpConvertUToPtr %_ptr_Generic_ulong %19 - OpStore %28 %20 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt deleted file mode 100644 index c41e792..0000000 --- a/ptx/src/test/spirv_run/popc.spvtxt +++ /dev/null @@ -1,52 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "popc" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %25 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %18 = OpBitCount %uint %14 - %13 = OpCopyObject %uint %18 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %19 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %19 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt deleted file mode 100644 index 644731b..0000000 --- a/ptx/src/test/spirv_run/pred_not.spvtxt +++ /dev/null @@ -1,82 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %42 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "pred_not" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %45 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_8 = OpConstant %ulong 8 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %true = OpConstantTrue %bool - %false = OpConstantFalse %bool - %ulong_1 = OpConstant %ulong 1 - %ulong_2 = OpConstant %ulong 2 - %1 = OpFunction %void None %45 - %14 = OpFunctionParameter %ulong - %15 = OpFunctionParameter %ulong - %40 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_bool Function - OpStore %2 %14 - OpStore %3 %15 - %16 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %16 - %17 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %17 - %19 = OpLoad %ulong %4 - %37 = OpConvertUToPtr %_ptr_Generic_ulong %19 - %18 = OpLoad %ulong %37 Aligned 8 - OpStore %6 %18 - %21 = OpLoad %ulong %4 - %38 = OpConvertUToPtr %_ptr_Generic_ulong %21 - %52 = OpBitcast %_ptr_Generic_uchar %38 - %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_8 - %34 = OpBitcast %_ptr_Generic_ulong %53 - %20 = OpLoad %ulong %34 Aligned 8 - OpStore %7 %20 - %23 = OpLoad %ulong %6 - %24 = OpLoad %ulong %7 - %22 = OpULessThan %bool %23 %24 - OpStore %9 %22 - %26 = OpLoad %bool %9 - %25 = OpSelect %bool %26 %false %true - OpStore %9 %25 - %27 = OpLoad %bool %9 - OpBranchConditional %27 %10 %11 - %10 = OpLabel - %28 = OpCopyObject %ulong %ulong_1 - OpStore %8 %28 - OpBranch %11 - %11 = OpLabel - %29 = OpLoad %bool %9 - OpBranchConditional %29 %13 %12 - %12 = OpLabel - %30 = OpCopyObject %ulong %ulong_2 - OpStore %8 %30 - OpBranch %13 - %13 = OpLabel - %31 = OpLoad %ulong %5 - %32 = OpLoad %ulong %8 - %39 = OpConvertUToPtr %_ptr_Generic_ulong %31 - OpStore %39 %32 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/prmt.spvtxt b/ptx/src/test/spirv_run/prmt.spvtxt deleted file mode 100644 index 060f534..0000000 --- a/ptx/src/test/spirv_run/prmt.spvtxt +++ /dev/null @@ -1,67 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %31 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "prmt" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %34 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %v4uchar = OpTypeVector %uchar 4 - %1 = OpFunction %void None %34 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %29 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %41 = OpBitcast %_ptr_Generic_uchar %24 - %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %42 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %26 = OpCopyObject %uint %17 - %27 = OpCopyObject %uint %18 - %44 = OpBitcast %v4uchar %26 - %45 = OpBitcast %v4uchar %27 - %46 = OpVectorShuffle %v4uchar %44 %45 4 0 6 7 - %25 = OpBitcast %uint %46 - %16 = OpCopyObject %uint %25 - OpStore %7 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %7 - %28 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %28 %20 Aligned 4 - OpReturn - OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt deleted file mode 100644 index 09fa0d9..0000000 --- a/ptx/src/test/spirv_run/rcp.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "rcp" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 native_recip %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt deleted file mode 100644 index ddb6a9e..0000000 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ /dev/null @@ -1,76 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - %34 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "reg_local" - OpExecutionMode %1 ContractionOff - OpDecorate %4 Alignment 8 - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %37 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_8 = OpConstant %uint 8 -%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8 -%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_1 = OpConstant %ulong 1 -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_0 = OpConstant %ulong 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_0_0 = OpConstant %ulong 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %1 = OpFunction %void None %37 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %32 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %24 = OpLoad %ulong %25 Aligned 8 - %12 = OpCopyObject %ulong %24 - OpStore %7 %12 - %14 = OpLoad %ulong %7 - %19 = OpIAdd %ulong %14 %ulong_1 - %46 = OpBitcast %_ptr_Function_ulong %4 - %26 = OpPtrCastToGeneric %_ptr_Generic_ulong %46 - %27 = OpCopyObject %ulong %19 - OpStore %26 %27 Aligned 8 - %47 = OpBitcast %_ptr_Function_ulong %4 - %28 = OpPtrCastToGeneric %_ptr_Generic_ulong %47 - %49 = OpBitcast %_ptr_Generic_uchar %28 - %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_0 - %21 = OpBitcast %_ptr_Generic_ulong %50 - %29 = OpLoad %ulong %21 Aligned 8 - %15 = OpCopyObject %ulong %29 - OpStore %7 %15 - %16 = OpLoad %ulong %6 - %17 = OpLoad %ulong %7 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - %52 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 - %53 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %52 %ulong_0_0 - %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %53 - %31 = OpCopyObject %ulong %17 - OpStore %23 %31 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt deleted file mode 100644 index 2184523..0000000 --- a/ptx/src/test/spirv_run/rem.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "rem" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %39 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %16 = OpSMod %uint %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rsqrt.spvtxt b/ptx/src/test/spirv_run/rsqrt.spvtxt deleted file mode 100644 index 6c87113..0000000 --- a/ptx/src/test/spirv_run/rsqrt.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "rsqrt" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %double = OpTypeFloat 64 -%_ptr_Function_double = OpTypePointer Function %double -%_ptr_Generic_double = OpTypePointer Generic %double - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_double Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_double %12 - %11 = OpLoad %double %17 Aligned 8 - OpStore %6 %11 - %14 = OpLoad %double %6 - %13 = OpExtInst %double %21 rsqrt %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %double %6 - %18 = OpConvertUToPtr %_ptr_Generic_double %15 - OpStore %18 %16 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt deleted file mode 100644 index 40c0bce..0000000 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ /dev/null @@ -1,61 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %29 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "selp" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %ushort = OpTypeInt 16 0 -%_ptr_Function_ushort = OpTypePointer Function %ushort -%_ptr_Generic_ushort = OpTypePointer Generic %ushort - %ulong_2 = OpConstant %ulong 2 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %bool = OpTypeBool - %false = OpConstantFalse %bool - %1 = OpFunction %void None %32 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %27 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ushort Function - %7 = OpVariable %_ptr_Function_ushort Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_ushort %13 - %12 = OpLoad %ushort %24 Aligned 2 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 - %39 = OpBitcast %_ptr_Generic_uchar %25 - %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 - %22 = OpBitcast %_ptr_Generic_ushort %40 - %14 = OpLoad %ushort %22 Aligned 2 - OpStore %7 %14 - %17 = OpLoad %ushort %6 - %18 = OpLoad %ushort %7 - %16 = OpSelect %ushort %false %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %ushort %6 - %26 = OpConvertUToPtr %_ptr_Generic_ushort %19 - OpStore %26 %20 Aligned 2 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/selp_true.spvtxt b/ptx/src/test/spirv_run/selp_true.spvtxt deleted file mode 100644 index 81b3b5f..0000000 --- a/ptx/src/test/spirv_run/selp_true.spvtxt +++ /dev/null @@ -1,61 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %29 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "selp_true" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %ushort = OpTypeInt 16 0 -%_ptr_Function_ushort = OpTypePointer Function %ushort -%_ptr_Generic_ushort = OpTypePointer Generic %ushort - %ulong_2 = OpConstant %ulong 2 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %bool = OpTypeBool - %true = OpConstantTrue %bool - %1 = OpFunction %void None %32 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %27 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ushort Function - %7 = OpVariable %_ptr_Function_ushort Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_ushort %13 - %12 = OpLoad %ushort %24 Aligned 2 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 - %39 = OpBitcast %_ptr_Generic_uchar %25 - %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 - %22 = OpBitcast %_ptr_Generic_ushort %40 - %14 = OpLoad %ushort %22 Aligned 2 - OpStore %7 %14 - %17 = OpLoad %ushort %6 - %18 = OpLoad %ushort %7 - %16 = OpSelect %ushort %true %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %ushort %6 - %26 = OpConvertUToPtr %_ptr_Generic_ushort %19 - OpStore %26 %20 Aligned 2 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt deleted file mode 100644 index 5868881..0000000 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ /dev/null @@ -1,77 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %40 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "setp" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %43 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_8 = OpConstant %ulong 8 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_1 = OpConstant %ulong 1 - %ulong_2 = OpConstant %ulong 2 - %1 = OpFunction %void None %43 - %14 = OpFunctionParameter %ulong - %15 = OpFunctionParameter %ulong - %38 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_bool Function - OpStore %2 %14 - OpStore %3 %15 - %16 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %16 - %17 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %17 - %19 = OpLoad %ulong %4 - %35 = OpConvertUToPtr %_ptr_Generic_ulong %19 - %18 = OpLoad %ulong %35 Aligned 8 - OpStore %6 %18 - %21 = OpLoad %ulong %4 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %21 - %50 = OpBitcast %_ptr_Generic_uchar %36 - %51 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %50 %ulong_8 - %32 = OpBitcast %_ptr_Generic_ulong %51 - %20 = OpLoad %ulong %32 Aligned 8 - OpStore %7 %20 - %23 = OpLoad %ulong %6 - %24 = OpLoad %ulong %7 - %22 = OpULessThan %bool %23 %24 - OpStore %9 %22 - %25 = OpLoad %bool %9 - OpBranchConditional %25 %10 %11 - %10 = OpLabel - %26 = OpCopyObject %ulong %ulong_1 - OpStore %8 %26 - OpBranch %11 - %11 = OpLabel - %27 = OpLoad %bool %9 - OpBranchConditional %27 %13 %12 - %12 = OpLabel - %28 = OpCopyObject %ulong %ulong_2 - OpStore %8 %28 - OpBranch %13 - %13 = OpLabel - %29 = OpLoad %ulong %5 - %30 = OpLoad %ulong %8 - %37 = OpConvertUToPtr %_ptr_Generic_ulong %29 - OpStore %37 %30 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_gt.spvtxt b/ptx/src/test/spirv_run/setp_gt.spvtxt deleted file mode 100644 index e9783f5..0000000 --- a/ptx/src/test/spirv_run/setp_gt.spvtxt +++ /dev/null @@ -1,79 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %40 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "setp_gt" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %43 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %43 - %14 = OpFunctionParameter %ulong - %15 = OpFunctionParameter %ulong - %38 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - %9 = OpVariable %_ptr_Function_bool Function - OpStore %2 %14 - OpStore %3 %15 - %16 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %16 - %17 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %17 - %19 = OpLoad %ulong %4 - %35 = OpConvertUToPtr %_ptr_Generic_float %19 - %18 = OpLoad %float %35 Aligned 4 - OpStore %6 %18 - %21 = OpLoad %ulong %4 - %36 = OpConvertUToPtr %_ptr_Generic_float %21 - %52 = OpBitcast %_ptr_Generic_uchar %36 - %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 - %34 = OpBitcast %_ptr_Generic_float %53 - %20 = OpLoad %float %34 Aligned 4 - OpStore %7 %20 - %23 = OpLoad %float %6 - %24 = OpLoad %float %7 - %22 = OpFOrdGreaterThan %bool %23 %24 - OpStore %9 %22 - %25 = OpLoad %bool %9 - OpBranchConditional %25 %10 %11 - %10 = OpLabel - %27 = OpLoad %float %6 - %26 = OpCopyObject %float %27 - OpStore %8 %26 - OpBranch %11 - %11 = OpLabel - %28 = OpLoad %bool %9 - OpBranchConditional %28 %13 %12 - %12 = OpLabel - %30 = OpLoad %float %7 - %29 = OpCopyObject %float %30 - OpStore %8 %29 - OpBranch %13 - %13 = OpLabel - %31 = OpLoad %ulong %5 - %32 = OpLoad %float %8 - %37 = OpConvertUToPtr %_ptr_Generic_float %31 - OpStore %37 %32 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_leu.spvtxt b/ptx/src/test/spirv_run/setp_leu.spvtxt deleted file mode 100644 index 1d2d781..0000000 --- a/ptx/src/test/spirv_run/setp_leu.spvtxt +++ /dev/null @@ -1,79 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %40 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "setp_leu" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %43 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %43 - %14 = OpFunctionParameter %ulong - %15 = OpFunctionParameter %ulong - %38 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - %9 = OpVariable %_ptr_Function_bool Function - OpStore %2 %14 - OpStore %3 %15 - %16 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %16 - %17 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %17 - %19 = OpLoad %ulong %4 - %35 = OpConvertUToPtr %_ptr_Generic_float %19 - %18 = OpLoad %float %35 Aligned 4 - OpStore %6 %18 - %21 = OpLoad %ulong %4 - %36 = OpConvertUToPtr %_ptr_Generic_float %21 - %52 = OpBitcast %_ptr_Generic_uchar %36 - %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 - %34 = OpBitcast %_ptr_Generic_float %53 - %20 = OpLoad %float %34 Aligned 4 - OpStore %7 %20 - %23 = OpLoad %float %6 - %24 = OpLoad %float %7 - %22 = OpFUnordLessThanEqual %bool %23 %24 - OpStore %9 %22 - %25 = OpLoad %bool %9 - OpBranchConditional %25 %10 %11 - %10 = OpLabel - %27 = OpLoad %float %6 - %26 = OpCopyObject %float %27 - OpStore %8 %26 - OpBranch %11 - %11 = OpLabel - %28 = OpLoad %bool %9 - OpBranchConditional %28 %13 %12 - %12 = OpLabel - %30 = OpLoad %float %7 - %29 = OpCopyObject %float %30 - OpStore %8 %29 - OpBranch %13 - %13 = OpLabel - %31 = OpLoad %ulong %5 - %32 = OpLoad %float %8 - %37 = OpConvertUToPtr %_ptr_Generic_float %31 - OpStore %37 %32 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_nan.spvtxt b/ptx/src/test/spirv_run/setp_nan.spvtxt deleted file mode 100644 index 2ee333a..0000000 --- a/ptx/src/test/spirv_run/setp_nan.spvtxt +++ /dev/null @@ -1,228 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %130 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "setp_nan" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %133 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %ulong_12 = OpConstant %ulong 12 - %ulong_16 = OpConstant %ulong 16 - %ulong_20 = OpConstant %ulong 20 - %ulong_24 = OpConstant %ulong 24 - %ulong_28 = OpConstant %ulong 28 - %uint_1 = OpConstant %uint 1 - %uint_0 = OpConstant %uint 0 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %uint_1_0 = OpConstant %uint 1 - %uint_0_0 = OpConstant %uint 0 - %ulong_4_0 = OpConstant %ulong 4 - %uint_1_1 = OpConstant %uint 1 - %uint_0_1 = OpConstant %uint 0 - %ulong_8_0 = OpConstant %ulong 8 - %uint_1_2 = OpConstant %uint 1 - %uint_0_2 = OpConstant %uint 0 - %ulong_12_0 = OpConstant %ulong 12 - %1 = OpFunction %void None %133 - %32 = OpFunctionParameter %ulong - %33 = OpFunctionParameter %ulong - %128 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - %9 = OpVariable %_ptr_Function_float Function - %10 = OpVariable %_ptr_Function_float Function - %11 = OpVariable %_ptr_Function_float Function - %12 = OpVariable %_ptr_Function_float Function - %13 = OpVariable %_ptr_Function_float Function - %14 = OpVariable %_ptr_Function_uint Function - %15 = OpVariable %_ptr_Function_bool Function - OpStore %2 %32 - OpStore %3 %33 - %34 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %34 - %35 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %35 - %37 = OpLoad %ulong %4 - %116 = OpConvertUToPtr %_ptr_Generic_float %37 - %36 = OpLoad %float %116 Aligned 4 - OpStore %6 %36 - %39 = OpLoad %ulong %4 - %117 = OpConvertUToPtr %_ptr_Generic_float %39 - %144 = OpBitcast %_ptr_Generic_uchar %117 - %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 - %89 = OpBitcast %_ptr_Generic_float %145 - %38 = OpLoad %float %89 Aligned 4 - OpStore %7 %38 - %41 = OpLoad %ulong %4 - %118 = OpConvertUToPtr %_ptr_Generic_float %41 - %146 = OpBitcast %_ptr_Generic_uchar %118 - %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 - %91 = OpBitcast %_ptr_Generic_float %147 - %40 = OpLoad %float %91 Aligned 4 - OpStore %8 %40 - %43 = OpLoad %ulong %4 - %119 = OpConvertUToPtr %_ptr_Generic_float %43 - %148 = OpBitcast %_ptr_Generic_uchar %119 - %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 - %93 = OpBitcast %_ptr_Generic_float %149 - %42 = OpLoad %float %93 Aligned 4 - OpStore %9 %42 - %45 = OpLoad %ulong %4 - %120 = OpConvertUToPtr %_ptr_Generic_float %45 - %150 = OpBitcast %_ptr_Generic_uchar %120 - %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 - %95 = OpBitcast %_ptr_Generic_float %151 - %44 = OpLoad %float %95 Aligned 4 - OpStore %10 %44 - %47 = OpLoad %ulong %4 - %121 = OpConvertUToPtr %_ptr_Generic_float %47 - %152 = OpBitcast %_ptr_Generic_uchar %121 - %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 - %97 = OpBitcast %_ptr_Generic_float %153 - %46 = OpLoad %float %97 Aligned 4 - OpStore %11 %46 - %49 = OpLoad %ulong %4 - %122 = OpConvertUToPtr %_ptr_Generic_float %49 - %154 = OpBitcast %_ptr_Generic_uchar %122 - %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 - %99 = OpBitcast %_ptr_Generic_float %155 - %48 = OpLoad %float %99 Aligned 4 - OpStore %12 %48 - %51 = OpLoad %ulong %4 - %123 = OpConvertUToPtr %_ptr_Generic_float %51 - %156 = OpBitcast %_ptr_Generic_uchar %123 - %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 - %101 = OpBitcast %_ptr_Generic_float %157 - %50 = OpLoad %float %101 Aligned 4 - OpStore %13 %50 - %53 = OpLoad %float %6 - %54 = OpLoad %float %7 - %158 = OpIsNan %bool %53 - %159 = OpIsNan %bool %54 - %52 = OpLogicalOr %bool %158 %159 - OpStore %15 %52 - %55 = OpLoad %bool %15 - OpBranchConditional %55 %16 %17 - %16 = OpLabel - %56 = OpCopyObject %uint %uint_1 - OpStore %14 %56 - OpBranch %17 - %17 = OpLabel - %57 = OpLoad %bool %15 - OpBranchConditional %57 %19 %18 - %18 = OpLabel - %58 = OpCopyObject %uint %uint_0 - OpStore %14 %58 - OpBranch %19 - %19 = OpLabel - %59 = OpLoad %ulong %5 - %60 = OpLoad %uint %14 - %124 = OpConvertUToPtr %_ptr_Generic_uint %59 - OpStore %124 %60 Aligned 4 - %62 = OpLoad %float %8 - %63 = OpLoad %float %9 - %161 = OpIsNan %bool %62 - %162 = OpIsNan %bool %63 - %61 = OpLogicalOr %bool %161 %162 - OpStore %15 %61 - %64 = OpLoad %bool %15 - OpBranchConditional %64 %20 %21 - %20 = OpLabel - %65 = OpCopyObject %uint %uint_1_0 - OpStore %14 %65 - OpBranch %21 - %21 = OpLabel - %66 = OpLoad %bool %15 - OpBranchConditional %66 %23 %22 - %22 = OpLabel - %67 = OpCopyObject %uint %uint_0_0 - OpStore %14 %67 - OpBranch %23 - %23 = OpLabel - %68 = OpLoad %ulong %5 - %69 = OpLoad %uint %14 - %125 = OpConvertUToPtr %_ptr_Generic_uint %68 - %163 = OpBitcast %_ptr_Generic_uchar %125 - %164 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %163 %ulong_4_0 - %107 = OpBitcast %_ptr_Generic_uint %164 - OpStore %107 %69 Aligned 4 - %71 = OpLoad %float %10 - %72 = OpLoad %float %11 - %165 = OpIsNan %bool %71 - %166 = OpIsNan %bool %72 - %70 = OpLogicalOr %bool %165 %166 - OpStore %15 %70 - %73 = OpLoad %bool %15 - OpBranchConditional %73 %24 %25 - %24 = OpLabel - %74 = OpCopyObject %uint %uint_1_1 - OpStore %14 %74 - OpBranch %25 - %25 = OpLabel - %75 = OpLoad %bool %15 - OpBranchConditional %75 %27 %26 - %26 = OpLabel - %76 = OpCopyObject %uint %uint_0_1 - OpStore %14 %76 - OpBranch %27 - %27 = OpLabel - %77 = OpLoad %ulong %5 - %78 = OpLoad %uint %14 - %126 = OpConvertUToPtr %_ptr_Generic_uint %77 - %167 = OpBitcast %_ptr_Generic_uchar %126 - %168 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %167 %ulong_8_0 - %111 = OpBitcast %_ptr_Generic_uint %168 - OpStore %111 %78 Aligned 4 - %80 = OpLoad %float %12 - %81 = OpLoad %float %13 - %169 = OpIsNan %bool %80 - %170 = OpIsNan %bool %81 - %79 = OpLogicalOr %bool %169 %170 - OpStore %15 %79 - %82 = OpLoad %bool %15 - OpBranchConditional %82 %28 %29 - %28 = OpLabel - %83 = OpCopyObject %uint %uint_1_2 - OpStore %14 %83 - OpBranch %29 - %29 = OpLabel - %84 = OpLoad %bool %15 - OpBranchConditional %84 %31 %30 - %30 = OpLabel - %85 = OpCopyObject %uint %uint_0_2 - OpStore %14 %85 - OpBranch %31 - %31 = OpLabel - %86 = OpLoad %ulong %5 - %87 = OpLoad %uint %14 - %127 = OpConvertUToPtr %_ptr_Generic_uint %86 - %171 = OpBitcast %_ptr_Generic_uchar %127 - %172 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %171 %ulong_12_0 - %115 = OpBitcast %_ptr_Generic_uint %172 - OpStore %115 %87 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_num.spvtxt b/ptx/src/test/spirv_run/setp_num.spvtxt deleted file mode 100644 index c576a50..0000000 --- a/ptx/src/test/spirv_run/setp_num.spvtxt +++ /dev/null @@ -1,240 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %130 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "setp_num" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %133 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_float = OpTypePointer Generic %float - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %ulong_8 = OpConstant %ulong 8 - %ulong_12 = OpConstant %ulong 12 - %ulong_16 = OpConstant %ulong 16 - %ulong_20 = OpConstant %ulong 20 - %ulong_24 = OpConstant %ulong 24 - %ulong_28 = OpConstant %ulong 28 - %true = OpConstantTrue %bool - %false = OpConstantFalse %bool - %uint_2 = OpConstant %uint 2 - %uint_0 = OpConstant %uint 0 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %true_0 = OpConstantTrue %bool - %false_0 = OpConstantFalse %bool - %uint_2_0 = OpConstant %uint 2 - %uint_0_0 = OpConstant %uint 0 - %ulong_4_0 = OpConstant %ulong 4 - %true_1 = OpConstantTrue %bool - %false_1 = OpConstantFalse %bool - %uint_2_1 = OpConstant %uint 2 - %uint_0_1 = OpConstant %uint 0 - %ulong_8_0 = OpConstant %ulong 8 - %true_2 = OpConstantTrue %bool - %false_2 = OpConstantFalse %bool - %uint_2_2 = OpConstant %uint 2 - %uint_0_2 = OpConstant %uint 0 - %ulong_12_0 = OpConstant %ulong 12 - %1 = OpFunction %void None %133 - %32 = OpFunctionParameter %ulong - %33 = OpFunctionParameter %ulong - %128 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - %7 = OpVariable %_ptr_Function_float Function - %8 = OpVariable %_ptr_Function_float Function - %9 = OpVariable %_ptr_Function_float Function - %10 = OpVariable %_ptr_Function_float Function - %11 = OpVariable %_ptr_Function_float Function - %12 = OpVariable %_ptr_Function_float Function - %13 = OpVariable %_ptr_Function_float Function - %14 = OpVariable %_ptr_Function_uint Function - %15 = OpVariable %_ptr_Function_bool Function - OpStore %2 %32 - OpStore %3 %33 - %34 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %34 - %35 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %35 - %37 = OpLoad %ulong %4 - %116 = OpConvertUToPtr %_ptr_Generic_float %37 - %36 = OpLoad %float %116 Aligned 4 - OpStore %6 %36 - %39 = OpLoad %ulong %4 - %117 = OpConvertUToPtr %_ptr_Generic_float %39 - %144 = OpBitcast %_ptr_Generic_uchar %117 - %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 - %89 = OpBitcast %_ptr_Generic_float %145 - %38 = OpLoad %float %89 Aligned 4 - OpStore %7 %38 - %41 = OpLoad %ulong %4 - %118 = OpConvertUToPtr %_ptr_Generic_float %41 - %146 = OpBitcast %_ptr_Generic_uchar %118 - %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 - %91 = OpBitcast %_ptr_Generic_float %147 - %40 = OpLoad %float %91 Aligned 4 - OpStore %8 %40 - %43 = OpLoad %ulong %4 - %119 = OpConvertUToPtr %_ptr_Generic_float %43 - %148 = OpBitcast %_ptr_Generic_uchar %119 - %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 - %93 = OpBitcast %_ptr_Generic_float %149 - %42 = OpLoad %float %93 Aligned 4 - OpStore %9 %42 - %45 = OpLoad %ulong %4 - %120 = OpConvertUToPtr %_ptr_Generic_float %45 - %150 = OpBitcast %_ptr_Generic_uchar %120 - %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 - %95 = OpBitcast %_ptr_Generic_float %151 - %44 = OpLoad %float %95 Aligned 4 - OpStore %10 %44 - %47 = OpLoad %ulong %4 - %121 = OpConvertUToPtr %_ptr_Generic_float %47 - %152 = OpBitcast %_ptr_Generic_uchar %121 - %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 - %97 = OpBitcast %_ptr_Generic_float %153 - %46 = OpLoad %float %97 Aligned 4 - OpStore %11 %46 - %49 = OpLoad %ulong %4 - %122 = OpConvertUToPtr %_ptr_Generic_float %49 - %154 = OpBitcast %_ptr_Generic_uchar %122 - %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 - %99 = OpBitcast %_ptr_Generic_float %155 - %48 = OpLoad %float %99 Aligned 4 - OpStore %12 %48 - %51 = OpLoad %ulong %4 - %123 = OpConvertUToPtr %_ptr_Generic_float %51 - %156 = OpBitcast %_ptr_Generic_uchar %123 - %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 - %101 = OpBitcast %_ptr_Generic_float %157 - %50 = OpLoad %float %101 Aligned 4 - OpStore %13 %50 - %53 = OpLoad %float %6 - %54 = OpLoad %float %7 - %158 = OpIsNan %bool %53 - %159 = OpIsNan %bool %54 - %160 = OpLogicalOr %bool %158 %159 - %52 = OpSelect %bool %160 %false %true - OpStore %15 %52 - %55 = OpLoad %bool %15 - OpBranchConditional %55 %16 %17 - %16 = OpLabel - %56 = OpCopyObject %uint %uint_2 - OpStore %14 %56 - OpBranch %17 - %17 = OpLabel - %57 = OpLoad %bool %15 - OpBranchConditional %57 %19 %18 - %18 = OpLabel - %58 = OpCopyObject %uint %uint_0 - OpStore %14 %58 - OpBranch %19 - %19 = OpLabel - %59 = OpLoad %ulong %5 - %60 = OpLoad %uint %14 - %124 = OpConvertUToPtr %_ptr_Generic_uint %59 - OpStore %124 %60 Aligned 4 - %62 = OpLoad %float %8 - %63 = OpLoad %float %9 - %164 = OpIsNan %bool %62 - %165 = OpIsNan %bool %63 - %166 = OpLogicalOr %bool %164 %165 - %61 = OpSelect %bool %166 %false_0 %true_0 - OpStore %15 %61 - %64 = OpLoad %bool %15 - OpBranchConditional %64 %20 %21 - %20 = OpLabel - %65 = OpCopyObject %uint %uint_2_0 - OpStore %14 %65 - OpBranch %21 - %21 = OpLabel - %66 = OpLoad %bool %15 - OpBranchConditional %66 %23 %22 - %22 = OpLabel - %67 = OpCopyObject %uint %uint_0_0 - OpStore %14 %67 - OpBranch %23 - %23 = OpLabel - %68 = OpLoad %ulong %5 - %69 = OpLoad %uint %14 - %125 = OpConvertUToPtr %_ptr_Generic_uint %68 - %169 = OpBitcast %_ptr_Generic_uchar %125 - %170 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %169 %ulong_4_0 - %107 = OpBitcast %_ptr_Generic_uint %170 - OpStore %107 %69 Aligned 4 - %71 = OpLoad %float %10 - %72 = OpLoad %float %11 - %171 = OpIsNan %bool %71 - %172 = OpIsNan %bool %72 - %173 = OpLogicalOr %bool %171 %172 - %70 = OpSelect %bool %173 %false_1 %true_1 - OpStore %15 %70 - %73 = OpLoad %bool %15 - OpBranchConditional %73 %24 %25 - %24 = OpLabel - %74 = OpCopyObject %uint %uint_2_1 - OpStore %14 %74 - OpBranch %25 - %25 = OpLabel - %75 = OpLoad %bool %15 - OpBranchConditional %75 %27 %26 - %26 = OpLabel - %76 = OpCopyObject %uint %uint_0_1 - OpStore %14 %76 - OpBranch %27 - %27 = OpLabel - %77 = OpLoad %ulong %5 - %78 = OpLoad %uint %14 - %126 = OpConvertUToPtr %_ptr_Generic_uint %77 - %176 = OpBitcast %_ptr_Generic_uchar %126 - %177 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %176 %ulong_8_0 - %111 = OpBitcast %_ptr_Generic_uint %177 - OpStore %111 %78 Aligned 4 - %80 = OpLoad %float %12 - %81 = OpLoad %float %13 - %178 = OpIsNan %bool %80 - %179 = OpIsNan %bool %81 - %180 = OpLogicalOr %bool %178 %179 - %79 = OpSelect %bool %180 %false_2 %true_2 - OpStore %15 %79 - %82 = OpLoad %bool %15 - OpBranchConditional %82 %28 %29 - %28 = OpLabel - %83 = OpCopyObject %uint %uint_2_2 - OpStore %14 %83 - OpBranch %29 - %29 = OpLabel - %84 = OpLoad %bool %15 - OpBranchConditional %84 %31 %30 - %30 = OpLabel - %85 = OpCopyObject %uint %uint_0_2 - OpStore %14 %85 - OpBranch %31 - %31 = OpLabel - %86 = OpLoad %ulong %5 - %87 = OpLoad %uint %14 - %127 = OpConvertUToPtr %_ptr_Generic_uint %86 - %183 = OpBitcast %_ptr_Generic_uchar %127 - %184 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %183 %ulong_12_0 - %115 = OpBitcast %_ptr_Generic_uint %184 - OpStore %115 %87 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt deleted file mode 100644 index 787a71c..0000000 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ /dev/null @@ -1,73 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %32 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_ptr_32" %4 - OpExecutionMode %1 ContractionOff - OpDecorate %4 Alignment 4 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_128 = OpConstant %uint 128 -%_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 -%_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup - %ulong = OpTypeInt 64 0 - %40 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %ulong_0 = OpConstant %ulong 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpFunction %void None %40 - %10 = OpFunctionParameter %ulong - %11 = OpFunctionParameter %ulong - %30 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_uint Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %10 - OpStore %3 %11 - %12 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %12 - %13 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %13 - %25 = OpConvertPtrToU %uint %4 - %14 = OpCopyObject %uint %25 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - %15 = OpLoad %ulong %26 Aligned 8 - OpStore %8 %15 - %17 = OpLoad %uint %7 - %18 = OpLoad %ulong %8 - %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 - OpStore %27 %18 Aligned 8 - %20 = OpLoad %uint %7 - %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 - %46 = OpBitcast %_ptr_Workgroup_uchar %28 - %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 - %24 = OpBitcast %_ptr_Workgroup_ulong %47 - %19 = OpLoad %ulong %24 Aligned 8 - OpStore %9 %19 - %21 = OpLoad %ulong %6 - %22 = OpLoad %ulong %9 - %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 - OpStore %29 %22 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt deleted file mode 100644 index 14926ef..0000000 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ /dev/null @@ -1,64 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 - OpExecutionMode %2 ContractionOff - OpDecorate %1 Alignment 4 - %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpVariable %_ptr_Workgroup_uchar Workgroup - %ulong = OpTypeInt 64 0 - %35 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %35 - %10 = OpFunctionParameter %ulong - %11 = OpFunctionParameter %ulong - %28 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - OpStore %3 %10 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %13 = OpLoad %ulong %4 Aligned 8 - OpStore %6 %13 - %23 = OpConvertPtrToU %ulong %1 - %14 = OpCopyObject %ulong %23 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - %15 = OpLoad %ulong %24 Aligned 8 - OpStore %8 %15 - %17 = OpLoad %ulong %7 - %18 = OpLoad %ulong %8 - %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 - OpStore %25 %18 Aligned 8 - %20 = OpLoad %ulong %7 - %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 - %19 = OpLoad %ulong %26 Aligned 8 - OpStore %9 %19 - %21 = OpLoad %ulong %6 - %22 = OpLoad %ulong %9 - %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 - OpStore %27 %22 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt deleted file mode 100644 index 90fc156..0000000 --- a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt +++ /dev/null @@ -1,118 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %61 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %27 "shared_unify_extern" %1 %2 - OpExecutionMode %27 ContractionOff - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup - %uint_4 = OpConstant %uint 4 -%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 -%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4 - %2 = OpVariable %_ptr_Workgroup__arr_uint_uint_4 Workgroup - %ulong = OpTypeInt 64 0 - %uint_4_0 = OpConstant %uint 4 - %70 = OpTypeFunction %ulong %_ptr_Workgroup_uint %_ptr_Workgroup__arr_uint_uint_4 - %uint_4_1 = OpConstant %uint 4 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_4_2 = OpConstant %uint 4 - %75 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup__arr_uint_uint_4 - %uint_4_3 = OpConstant %uint 4 - %77 = OpTypeFunction %void %ulong %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_8 = OpConstant %ulong 8 - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %3 = OpFunction %ulong None %70 - %57 = OpFunctionParameter %_ptr_Workgroup_uint - %58 = OpFunctionParameter %_ptr_Workgroup__arr_uint_uint_4 - %16 = OpLabel - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %13 = OpBitcast %_ptr_Workgroup_ulong %58 - %7 = OpLoad %ulong %13 Aligned 8 - OpStore %5 %7 - %14 = OpBitcast %_ptr_Workgroup_ulong %57 - %8 = OpLoad %ulong %14 Aligned 8 - OpStore %6 %8 - %10 = OpLoad %ulong %6 - %11 = OpLoad %ulong %5 - %15 = OpIAdd %ulong %10 %11 - %9 = OpCopyObject %ulong %15 - OpStore %4 %9 - %12 = OpLoad %ulong %4 - OpReturnValue %12 - OpFunctionEnd - %17 = OpFunction %ulong None %75 - %20 = OpFunctionParameter %ulong - %59 = OpFunctionParameter %_ptr_Workgroup_uint - %60 = OpFunctionParameter %_ptr_Workgroup__arr_uint_uint_4 - %26 = OpLabel - %19 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - OpStore %19 %20 - %21 = OpLoad %ulong %19 - %24 = OpBitcast %_ptr_Workgroup_ulong %59 - %25 = OpCopyObject %ulong %21 - OpStore %24 %25 Aligned 8 - %22 = OpFunctionCall %ulong %3 %59 %60 - OpStore %18 %22 - %23 = OpLoad %ulong %18 - OpReturnValue %23 - OpFunctionEnd - %27 = OpFunction %void None %77 - %34 = OpFunctionParameter %ulong - %35 = OpFunctionParameter %ulong - %55 = OpLabel - %28 = OpVariable %_ptr_Function_ulong Function - %29 = OpVariable %_ptr_Function_ulong Function - %30 = OpVariable %_ptr_Function_ulong Function - %31 = OpVariable %_ptr_Function_ulong Function - %32 = OpVariable %_ptr_Function_ulong Function - %33 = OpVariable %_ptr_Function_ulong Function - OpStore %28 %34 - OpStore %29 %35 - %36 = OpLoad %ulong %28 Aligned 8 - OpStore %30 %36 - %37 = OpLoad %ulong %29 Aligned 8 - OpStore %31 %37 - %39 = OpLoad %ulong %30 - %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %39 - %38 = OpLoad %ulong %49 Aligned 8 - OpStore %32 %38 - %41 = OpLoad %ulong %30 - %50 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %41 - %81 = OpBitcast %_ptr_CrossWorkgroup_uchar %50 - %82 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %81 %ulong_8 - %48 = OpBitcast %_ptr_CrossWorkgroup_ulong %82 - %40 = OpLoad %ulong %48 Aligned 8 - OpStore %33 %40 - %42 = OpLoad %ulong %33 - %51 = OpBitcast %_ptr_Workgroup_ulong %2 - OpStore %51 %42 Aligned 8 - %44 = OpLoad %ulong %32 - %53 = OpCopyObject %ulong %44 - %52 = OpFunctionCall %ulong %17 %53 %1 %2 - %43 = OpCopyObject %ulong %52 - OpStore %33 %43 - %45 = OpLoad %ulong %31 - %46 = OpLoad %ulong %33 - %54 = OpConvertUToPtr %_ptr_Generic_ulong %45 - OpStore %54 %46 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_unify_local.spvtxt b/ptx/src/test/spirv_run/shared_unify_local.spvtxt deleted file mode 100644 index dc00c2f..0000000 --- a/ptx/src/test/spirv_run/shared_unify_local.spvtxt +++ /dev/null @@ -1,117 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %64 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "shared_unify_local" %1 %5 - OpExecutionMode %31 ContractionOff - OpDecorate %5 Alignment 4 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup - %ulong = OpTypeInt 64 0 -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %5 = OpVariable %_ptr_Workgroup_ulong Workgroup - %70 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %72 = OpTypeFunction %ulong %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong - %73 = OpTypeFunction %void %ulong %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_8 = OpConstant %ulong 8 - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %2 = OpFunction %ulong None %70 - %7 = OpFunctionParameter %ulong - %60 = OpFunctionParameter %_ptr_Workgroup_uint - %61 = OpFunctionParameter %_ptr_Workgroup_ulong - %17 = OpLabel - %4 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - OpStore %4 %7 - %8 = OpLoad %ulong %4 - OpStore %61 %8 Aligned 8 - %9 = OpLoad %ulong %61 Aligned 8 - OpStore %6 %9 - %15 = OpBitcast %_ptr_Workgroup_ulong %60 - %10 = OpLoad %ulong %15 Aligned 8 - OpStore %4 %10 - %12 = OpLoad %ulong %4 - %13 = OpLoad %ulong %6 - %16 = OpIAdd %ulong %12 %13 - %11 = OpCopyObject %ulong %16 - OpStore %3 %11 - %14 = OpLoad %ulong %3 - OpReturnValue %14 - OpFunctionEnd - %18 = OpFunction %ulong None %72 - %22 = OpFunctionParameter %ulong - %23 = OpFunctionParameter %ulong - %62 = OpFunctionParameter %_ptr_Workgroup_uint - %63 = OpFunctionParameter %_ptr_Workgroup_ulong - %30 = OpLabel - %20 = OpVariable %_ptr_Function_ulong Function - %21 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_ulong Function - OpStore %20 %22 - OpStore %21 %23 - %24 = OpLoad %ulong %20 - %28 = OpBitcast %_ptr_Workgroup_ulong %62 - %29 = OpCopyObject %ulong %24 - OpStore %28 %29 Aligned 8 - %26 = OpLoad %ulong %21 - %25 = OpFunctionCall %ulong %2 %26 %62 %63 - OpStore %19 %25 - %27 = OpLoad %ulong %19 - OpReturnValue %27 - OpFunctionEnd - %31 = OpFunction %void None %73 - %38 = OpFunctionParameter %ulong - %39 = OpFunctionParameter %ulong - %58 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function - %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_ulong Function - %37 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %38 - OpStore %33 %39 - %40 = OpLoad %ulong %32 Aligned 8 - OpStore %34 %40 - %41 = OpLoad %ulong %33 Aligned 8 - OpStore %35 %41 - %43 = OpLoad %ulong %34 - %53 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %43 - %42 = OpLoad %ulong %53 Aligned 8 - OpStore %36 %42 - %45 = OpLoad %ulong %34 - %54 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %45 - %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %54 - %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %ulong_8 - %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %78 - %44 = OpLoad %ulong %52 Aligned 8 - OpStore %37 %44 - %47 = OpLoad %ulong %36 - %48 = OpLoad %ulong %37 - %56 = OpCopyObject %ulong %47 - %55 = OpFunctionCall %ulong %18 %56 %48 %1 %5 - %46 = OpCopyObject %ulong %55 - OpStore %37 %46 - %49 = OpLoad %ulong %35 - %50 = OpLoad %ulong %37 - %57 = OpConvertUToPtr %_ptr_Generic_ulong %49 - OpStore %57 %50 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt deleted file mode 100644 index fbbfe4a..0000000 --- a/ptx/src/test/spirv_run/shared_variable.spvtxt +++ /dev/null @@ -1,61 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - OpExtension "SPV_KHR_float_controls" - OpExtension "SPV_KHR_no_integer_wrap_decoration" - %25 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_variable" %4 - OpExecutionMode %1 ContractionOff - OpDecorate %4 Alignment 4 - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %uint_128 = OpConstant %uint 128 -%_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 -%_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup - %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %1 = OpFunction %void None %33 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %5 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %6 %12 - %14 = OpLoad %ulong %5 - %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %14 - %13 = OpLoad %ulong %19 Aligned 8 - OpStore %7 %13 - %15 = OpLoad %ulong %7 - %20 = OpBitcast %_ptr_Workgroup_ulong %4 - OpStore %20 %15 Aligned 8 - %21 = OpBitcast %_ptr_Workgroup_ulong %4 - %16 = OpLoad %ulong %21 Aligned 8 - OpStore %8 %16 - %17 = OpLoad %ulong %6 - %18 = OpLoad %ulong %8 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 - OpStore %22 %18 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt deleted file mode 100644 index 2a1249e..0000000 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ /dev/null @@ -1,51 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %25 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shl" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %uint = OpTypeInt 32 0 - %uint_2 = OpConstant %uint 2 - %1 = OpFunction %void None %28 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %21 = OpCopyObject %ulong %15 - %32 = OpUConvert %ulong %uint_2 - %20 = OpShiftLeftLogical %ulong %21 %32 - %14 = OpCopyObject %ulong %20 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %22 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shl_link_hack.ptx b/ptx/src/test/spirv_run/shl_link_hack.ptx deleted file mode 100644 index a32555c..0000000 --- a/ptx/src/test/spirv_run/shl_link_hack.ptx +++ /dev/null @@ -1,30 +0,0 @@ -// HACK ALERT -// This test is for testing workaround for a bug in IGC where linking fails -// if there is shl/shr with different width of value and shift - -.version 6.5 -.target sm_30 -.address_size 64 - -.visible .entry shl_link_hack( - .param .u64 input, - .param .u64 output -) -{ - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 temp; - .reg .u64 temp2; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - // Here only to trigger linking - .reg .u32 unused; - atom.inc.u32 unused, [out_addr], 2000000; - - ld.u64 temp, [in_addr]; - shl.b64 temp2, temp, 2; - st.u64 [out_addr], temp2; - ret; -} diff --git a/ptx/src/test/spirv_run/shl_link_hack.spvtxt b/ptx/src/test/spirv_run/shl_link_hack.spvtxt deleted file mode 100644 index 7e53af8..0000000 --- a/ptx/src/test/spirv_run/shl_link_hack.spvtxt +++ /dev/null @@ -1,65 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %34 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shl_link_hack" - OpDecorate %29 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %38 = OpTypeFunction %uint %_ptr_Generic_uint %uint - %ulong = OpTypeInt 64 0 - %40 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function_uint = OpTypePointer Function %uint -%uint_2000000 = OpConstant %uint 2000000 -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %uint_2 = OpConstant %uint 2 - %29 = OpFunction %uint None %38 - %31 = OpFunctionParameter %_ptr_Generic_uint - %32 = OpFunctionParameter %uint - OpFunctionEnd - %1 = OpFunction %void None %40 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %5 - %23 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpFunctionCall %uint %29 %23 %uint_2000000 - OpStore %8 %13 - %16 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %16 - %15 = OpLoad %ulong %24 Aligned 8 - OpStore %6 %15 - %18 = OpLoad %ulong %6 - %26 = OpCopyObject %ulong %18 - %44 = OpUConvert %ulong %uint_2 - %25 = OpShiftLeftLogical %ulong %26 %44 - %17 = OpCopyObject %ulong %25 - OpStore %7 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %ulong %7 - %27 = OpConvertUToPtr %_ptr_Generic_ulong %19 - OpStore %27 %20 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shr.spvtxt b/ptx/src/test/spirv_run/shr.spvtxt deleted file mode 100644 index 249e71a..0000000 --- a/ptx/src/test/spirv_run/shr.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %22 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shr" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %25 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %uint_1 = OpConstant %uint 1 - %1 = OpFunction %void None %25 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %20 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_uint %12 - %11 = OpLoad %uint %18 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %uint %6 - %13 = OpShiftRightArithmetic %uint %14 %uint_1 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %uint %6 - %19 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %19 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sign_extend.spvtxt b/ptx/src/test/spirv_run/sign_extend.spvtxt deleted file mode 100644 index 5ceffed..0000000 --- a/ptx/src/test/spirv_run/sign_extend.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %20 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "sign_extend" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %23 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint - %ushort = OpTypeInt 16 0 -%_ptr_Generic_ushort = OpTypePointer Generic %ushort -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %23 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %18 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %16 = OpConvertUToPtr %_ptr_Generic_ushort %12 - %15 = OpLoad %ushort %16 Aligned 2 - %11 = OpSConvert %uint %15 - OpStore %6 %11 - %13 = OpLoad %ulong %5 - %14 = OpLoad %uint %6 - %17 = OpConvertUToPtr %_ptr_Generic_uint %13 - OpStore %17 %14 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sin.spvtxt b/ptx/src/test/spirv_run/sin.spvtxt deleted file mode 100644 index 6dd3e53..0000000 --- a/ptx/src/test/spirv_run/sin.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "sin" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 sin %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sqrt.spvtxt b/ptx/src/test/spirv_run/sqrt.spvtxt deleted file mode 100644 index 1c65aa3..0000000 --- a/ptx/src/test/spirv_run/sqrt.spvtxt +++ /dev/null @@ -1,48 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "sqrt" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_Generic_float = OpTypePointer Generic %float - %1 = OpFunction %void None %24 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %19 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %9 - %10 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_Generic_float %12 - %11 = OpLoad %float %17 Aligned 4 - OpStore %6 %11 - %14 = OpLoad %float %6 - %13 = OpExtInst %float %21 sqrt %14 - OpStore %6 %13 - %15 = OpLoad %ulong %5 - %16 = OpLoad %float %6 - %18 = OpConvertUToPtr %_ptr_Generic_float %15 - OpStore %18 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt deleted file mode 100644 index e2d4db6..0000000 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ /dev/null @@ -1,93 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %56 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "stateful_ld_st_ntid" - OpExecutionMode %1 ContractionOff - OpDecorate %12 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %60 = OpTypeFunction %uint %uchar -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %62 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar -%_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uchar_0 = OpConstant %uchar 0 -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %12 = OpFunction %uint None %60 - %14 = OpFunctionParameter %uchar - OpFunctionEnd - %1 = OpFunction %void None %62 - %25 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %26 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %54 = OpLabel - %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - OpStore %17 %25 - OpStore %18 %26 - %47 = OpBitcast %_ptr_Function_ulong %17 - %46 = OpLoad %ulong %47 Aligned 8 - %19 = OpCopyObject %ulong %46 - %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 - OpStore %15 %27 - %49 = OpBitcast %_ptr_Function_ulong %18 - %48 = OpLoad %ulong %49 Aligned 8 - %20 = OpCopyObject %ulong %48 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 - OpStore %16 %28 - %29 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %22 = OpConvertPtrToU %ulong %29 - %21 = OpCopyObject %ulong %22 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %21 - OpStore %15 %30 - %31 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %24 = OpConvertPtrToU %ulong %31 - %23 = OpCopyObject %ulong %24 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 - OpStore %16 %32 - %11 = OpFunctionCall %uint %12 %uchar_0 - %33 = OpCopyObject %uint %11 - OpStore %6 %33 - %35 = OpLoad %uint %6 - %67 = OpBitcast %uint %35 - %34 = OpUConvert %ulong %67 - OpStore %7 %34 - %37 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %38 = OpLoad %ulong %7 - %50 = OpCopyObject %ulong %38 - %68 = OpBitcast %_ptr_CrossWorkgroup_uchar %37 - %69 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %68 %50 - %36 = OpBitcast %_ptr_CrossWorkgroup_uchar %69 - OpStore %15 %36 - %40 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %41 = OpLoad %ulong %7 - %51 = OpCopyObject %ulong %41 - %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %40 - %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %51 - %39 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 - OpStore %16 %39 - %43 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %43 - %42 = OpLoad %ulong %52 Aligned 8 - OpStore %8 %42 - %44 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %45 = OpLoad %ulong %8 - %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %44 - OpStore %53 %45 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt deleted file mode 100644 index 5da0ef3..0000000 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ /dev/null @@ -1,97 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %64 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" - OpExecutionMode %1 ContractionOff - OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %68 = OpTypeFunction %uint %uchar -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %70 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar -%_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uchar_0 = OpConstant %uchar 0 -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %16 = OpFunction %uint None %68 - %18 = OpFunctionParameter %uchar - OpFunctionEnd - %1 = OpFunction %void None %70 - %33 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %34 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %62 = OpLabel - %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %10 = OpVariable %_ptr_Function_uint Function - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - OpStore %25 %33 - OpStore %26 %34 - %55 = OpBitcast %_ptr_Function_ulong %25 - %54 = OpLoad %ulong %55 Aligned 8 - %27 = OpCopyObject %ulong %54 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27 - OpStore %19 %35 - %57 = OpBitcast %_ptr_Function_ulong %26 - %56 = OpLoad %ulong %57 Aligned 8 - %28 = OpCopyObject %ulong %56 - %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28 - OpStore %22 %36 - %37 = OpLoad %_ptr_CrossWorkgroup_uchar %19 - %30 = OpConvertPtrToU %ulong %37 - %29 = OpCopyObject %ulong %30 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29 - OpStore %20 %38 - %39 = OpLoad %_ptr_CrossWorkgroup_uchar %22 - %32 = OpConvertPtrToU %ulong %39 - %31 = OpCopyObject %ulong %32 - %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31 - OpStore %23 %40 - %15 = OpFunctionCall %uint %16 %uchar_0 - %41 = OpCopyObject %uint %15 - OpStore %10 %41 - %43 = OpLoad %uint %10 - %75 = OpBitcast %uint %43 - %42 = OpUConvert %ulong %75 - OpStore %11 %42 - %45 = OpLoad %_ptr_CrossWorkgroup_uchar %20 - %46 = OpLoad %ulong %11 - %58 = OpCopyObject %ulong %46 - %76 = OpBitcast %_ptr_CrossWorkgroup_uchar %45 - %77 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %76 %58 - %44 = OpBitcast %_ptr_CrossWorkgroup_uchar %77 - OpStore %21 %44 - %48 = OpLoad %_ptr_CrossWorkgroup_uchar %23 - %49 = OpLoad %ulong %11 - %59 = OpCopyObject %ulong %49 - %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %48 - %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %59 - %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 - OpStore %24 %47 - %51 = OpLoad %_ptr_CrossWorkgroup_uchar %21 - %60 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 - %50 = OpLoad %ulong %60 Aligned 8 - OpStore %12 %50 - %52 = OpLoad %_ptr_CrossWorkgroup_uchar %24 - %53 = OpLoad %ulong %12 - %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %52 - OpStore %61 %53 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt deleted file mode 100644 index 0ef5d28..0000000 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt +++ /dev/null @@ -1,107 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %70 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" - OpExecutionMode %1 ContractionOff - OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %uchar = OpTypeInt 8 0 - %74 = OpTypeFunction %uint %uchar -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %76 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar -%_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uchar_0 = OpConstant %uchar 0 - %ulong_0 = OpConstant %ulong 0 -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %ulong_0_0 = OpConstant %ulong 0 - %16 = OpFunction %uint None %74 - %18 = OpFunctionParameter %uchar - OpFunctionEnd - %1 = OpFunction %void None %76 - %35 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %36 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %68 = OpLabel - %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %10 = OpVariable %_ptr_Function_uint Function - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - OpStore %25 %35 - OpStore %26 %36 - %61 = OpBitcast %_ptr_Function_ulong %25 - %60 = OpLoad %ulong %61 Aligned 8 - %27 = OpCopyObject %ulong %60 - %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27 - OpStore %19 %37 - %63 = OpBitcast %_ptr_Function_ulong %26 - %62 = OpLoad %ulong %63 Aligned 8 - %28 = OpCopyObject %ulong %62 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28 - OpStore %22 %38 - %39 = OpLoad %_ptr_CrossWorkgroup_uchar %19 - %30 = OpConvertPtrToU %ulong %39 - %29 = OpCopyObject %ulong %30 - %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29 - OpStore %20 %40 - %41 = OpLoad %_ptr_CrossWorkgroup_uchar %22 - %32 = OpConvertPtrToU %ulong %41 - %31 = OpCopyObject %ulong %32 - %42 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31 - OpStore %23 %42 - %15 = OpFunctionCall %uint %16 %uchar_0 - %43 = OpCopyObject %uint %15 - OpStore %10 %43 - %45 = OpLoad %uint %10 - %81 = OpBitcast %uint %45 - %44 = OpUConvert %ulong %81 - OpStore %11 %44 - %46 = OpLoad %ulong %11 - %64 = OpCopyObject %ulong %46 - %33 = OpSNegate %ulong %64 - %48 = OpLoad %_ptr_CrossWorkgroup_uchar %20 - %82 = OpBitcast %_ptr_CrossWorkgroup_uchar %48 - %83 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %82 %33 - %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %83 - OpStore %21 %47 - %49 = OpLoad %ulong %11 - %65 = OpCopyObject %ulong %49 - %34 = OpSNegate %ulong %65 - %51 = OpLoad %_ptr_CrossWorkgroup_uchar %23 - %84 = OpBitcast %_ptr_CrossWorkgroup_uchar %51 - %85 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %84 %34 - %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %85 - OpStore %24 %50 - %53 = OpLoad %_ptr_CrossWorkgroup_uchar %21 - %66 = OpBitcast %_ptr_CrossWorkgroup_ulong %53 - %87 = OpBitcast %_ptr_CrossWorkgroup_uchar %66 - %88 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %87 %ulong_0 - %57 = OpBitcast %_ptr_CrossWorkgroup_ulong %88 - %52 = OpLoad %ulong %57 Aligned 8 - OpStore %12 %52 - %54 = OpLoad %_ptr_CrossWorkgroup_uchar %24 - %55 = OpLoad %ulong %12 - %67 = OpBitcast %_ptr_CrossWorkgroup_ulong %54 - %89 = OpBitcast %_ptr_CrossWorkgroup_uchar %67 - %90 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %89 %ulong_0_0 - %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %90 - OpStore %59 %55 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt deleted file mode 100644 index 7a142b7..0000000 --- a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt +++ /dev/null @@ -1,65 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %41 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "stateful_ld_st_simple" - %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %45 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %45 - %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %39 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %21 - OpStore %3 %22 - %14 = OpBitcast %_ptr_Function_ulong %2 - %13 = OpLoad %ulong %14 Aligned 8 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 - OpStore %9 %23 - %16 = OpBitcast %_ptr_Function_ulong %3 - %15 = OpLoad %ulong %16 Aligned 8 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 - OpStore %10 %24 - %25 = OpLoad %_ptr_CrossWorkgroup_uchar %9 - %18 = OpConvertPtrToU %ulong %25 - %34 = OpCopyObject %ulong %18 - %33 = OpCopyObject %ulong %34 - %17 = OpCopyObject %ulong %33 - %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17 - OpStore %11 %26 - %27 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %20 = OpConvertPtrToU %ulong %27 - %36 = OpCopyObject %ulong %20 - %35 = OpCopyObject %ulong %36 - %19 = OpCopyObject %ulong %35 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 - OpStore %12 %28 - %30 = OpLoad %_ptr_CrossWorkgroup_uchar %11 - %37 = OpBitcast %_ptr_CrossWorkgroup_ulong %30 - %29 = OpLoad %ulong %37 Aligned 8 - OpStore %8 %29 - %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12 - %32 = OpLoad %ulong %8 - %38 = OpBitcast %_ptr_CrossWorkgroup_ulong %31 - OpStore %38 %32 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_neg_offset.spvtxt b/ptx/src/test/spirv_run/stateful_neg_offset.spvtxt deleted file mode 100644 index 62843ca..0000000 --- a/ptx/src/test/spirv_run/stateful_neg_offset.spvtxt +++ /dev/null @@ -1,80 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "stateful_neg_offset" - %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar -%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %61 - %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %55 = OpLabel - %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %9 = OpVariable %_ptr_Function_ulong Function - OpStore %15 %29 - OpStore %16 %30 - %47 = OpBitcast %_ptr_Function_ulong %15 - %17 = OpLoad %ulong %47 Aligned 8 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17 - OpStore %10 %31 - %48 = OpBitcast %_ptr_Function_ulong %16 - %18 = OpLoad %ulong %48 Aligned 8 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 - OpStore %11 %32 - %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %20 = OpConvertPtrToU %ulong %33 - %50 = OpCopyObject %ulong %20 - %49 = OpCopyObject %ulong %50 - %19 = OpCopyObject %ulong %49 - %34 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 - OpStore %12 %34 - %35 = OpLoad %_ptr_CrossWorkgroup_uchar %11 - %22 = OpConvertPtrToU %ulong %35 - %52 = OpCopyObject %ulong %22 - %51 = OpCopyObject %ulong %52 - %21 = OpCopyObject %ulong %51 - %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %21 - OpStore %13 %36 - %37 = OpLoad %_ptr_CrossWorkgroup_uchar %12 - %24 = OpConvertPtrToU %ulong %37 - %38 = OpLoad %_ptr_CrossWorkgroup_uchar %13 - %25 = OpConvertPtrToU %ulong %38 - %23 = OpIAdd %ulong %24 %25 - %39 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 - OpStore %14 %39 - %40 = OpLoad %_ptr_CrossWorkgroup_uchar %12 - %27 = OpConvertPtrToU %ulong %40 - %41 = OpLoad %_ptr_CrossWorkgroup_uchar %13 - %28 = OpConvertPtrToU %ulong %41 - %26 = OpISub %ulong %27 %28 - %42 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 - OpStore %14 %42 - %44 = OpLoad %_ptr_CrossWorkgroup_uchar %12 - %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %44 - %43 = OpLoad %ulong %53 Aligned 8 - OpStore %9 %43 - %45 = OpLoad %_ptr_CrossWorkgroup_uchar %13 - %46 = OpLoad %ulong %9 - %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %45 - OpStore %54 %46 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sub.spvtxt b/ptx/src/test/spirv_run/sub.spvtxt deleted file mode 100644 index 05656dd..0000000 --- a/ptx/src/test/spirv_run/sub.spvtxt +++ /dev/null @@ -1,47 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %23 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "sub" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %26 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %26 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %21 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 Aligned 8 - OpStore %6 %12 - %15 = OpLoad %ulong %6 - %14 = OpISub %ulong %15 %ulong_1 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %ulong %7 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt deleted file mode 100644 index 8253bf9..0000000 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ /dev/null @@ -1,99 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %51 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %25 "vector" - %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v2uint = OpTypeVector %uint 2 - %55 = OpTypeFunction %v2uint %v2uint -%_ptr_Function_v2uint = OpTypePointer Function %v2uint -%_ptr_Function_uint = OpTypePointer Function %uint - %uint_0 = OpConstant %uint 0 - %uint_1 = OpConstant %uint 1 - %ulong = OpTypeInt 64 0 - %67 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %55 - %7 = OpFunctionParameter %v2uint - %24 = OpLabel - %3 = OpVariable %_ptr_Function_v2uint Function - %2 = OpVariable %_ptr_Function_v2uint Function - %4 = OpVariable %_ptr_Function_v2uint Function - %5 = OpVariable %_ptr_Function_uint Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %3 %7 - %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 - %9 = OpLoad %uint %59 - %8 = OpCopyObject %uint %9 - OpStore %5 %8 - %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 - %11 = OpLoad %uint %61 - %10 = OpCopyObject %uint %11 - OpStore %6 %10 - %13 = OpLoad %uint %5 - %14 = OpLoad %uint %6 - %12 = OpIAdd %uint %13 %14 - OpStore %6 %12 - %16 = OpLoad %uint %6 - %15 = OpCopyObject %uint %16 - %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 - OpStore %62 %15 - %18 = OpLoad %uint %6 - %17 = OpCopyObject %uint %18 - %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 - OpStore %63 %17 - %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 - %20 = OpLoad %uint %64 - %19 = OpCopyObject %uint %20 - %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 - OpStore %65 %19 - %22 = OpLoad %v2uint %4 - %21 = OpCopyObject %v2uint %22 - OpStore %2 %21 - %23 = OpLoad %v2uint %2 - OpReturnValue %23 - OpFunctionEnd - %25 = OpFunction %void None %67 - %34 = OpFunctionParameter %ulong - %35 = OpFunctionParameter %ulong - %49 = OpLabel - %26 = OpVariable %_ptr_Function_ulong Function - %27 = OpVariable %_ptr_Function_ulong Function - %28 = OpVariable %_ptr_Function_ulong Function - %29 = OpVariable %_ptr_Function_ulong Function - %30 = OpVariable %_ptr_Function_v2uint Function - %31 = OpVariable %_ptr_Function_uint Function - %32 = OpVariable %_ptr_Function_uint Function - %33 = OpVariable %_ptr_Function_ulong Function - OpStore %26 %34 - OpStore %27 %35 - %36 = OpLoad %ulong %26 Aligned 8 - OpStore %28 %36 - %37 = OpLoad %ulong %27 Aligned 8 - OpStore %29 %37 - %39 = OpLoad %ulong %28 - %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 - %38 = OpLoad %v2uint %46 Aligned 8 - OpStore %30 %38 - %41 = OpLoad %v2uint %30 - %40 = OpFunctionCall %v2uint %1 %41 - OpStore %30 %40 - %43 = OpLoad %v2uint %30 - %47 = OpBitcast %ulong %43 - %42 = OpCopyObject %ulong %47 - OpStore %33 %42 - %44 = OpLoad %ulong %29 - %45 = OpLoad %v2uint %30 - %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 - OpStore %48 %45 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector4.spvtxt b/ptx/src/test/spirv_run/vector4.spvtxt deleted file mode 100644 index 9b6349b..0000000 --- a/ptx/src/test/spirv_run/vector4.spvtxt +++ /dev/null @@ -1,56 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "vector4" - OpExecutionMode %1 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 - %v4uint = OpTypeVector %uint 4 -%_ptr_Function_v4uint = OpTypePointer Function %v4uint -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_v4uint = OpTypePointer Generic %v4uint - %uint_3 = OpConstant %uint 3 -%_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %27 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %22 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_v4uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %18 = OpConvertUToPtr %_ptr_Generic_v4uint %13 - %12 = OpLoad %v4uint %18 Aligned 16 - OpStore %6 %12 - %35 = OpInBoundsAccessChain %_ptr_Function_uint %6 %uint_3 - %15 = OpLoad %uint %35 - %20 = OpCopyObject %uint %15 - %19 = OpCopyObject %uint %20 - %14 = OpCopyObject %uint %19 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %7 - %21 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %21 %17 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt deleted file mode 100644 index 802c69b..0000000 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ /dev/null @@ -1,125 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %61 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "vector_extract" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %64 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %ushort = OpTypeInt 16 0 -%_ptr_Function_ushort = OpTypePointer Function %ushort - %v4ushort = OpTypeVector %ushort 4 -%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort - %uchar = OpTypeInt 8 0 - %v4uchar = OpTypeVector %uchar 4 -%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar - %1 = OpFunction %void None %64 - %17 = OpFunctionParameter %ulong - %18 = OpFunctionParameter %ulong - %59 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ushort Function - %7 = OpVariable %_ptr_Function_ushort Function - %8 = OpVariable %_ptr_Function_ushort Function - %9 = OpVariable %_ptr_Function_ushort Function - %10 = OpVariable %_ptr_Function_v4ushort Function - OpStore %2 %17 - OpStore %3 %18 - %19 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %19 - %20 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %20 - %21 = OpLoad %ulong %4 - %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 - %11 = OpLoad %v4uchar %49 Aligned 4 - %50 = OpCompositeExtract %uchar %11 0 - %51 = OpCompositeExtract %uchar %11 1 - %52 = OpCompositeExtract %uchar %11 2 - %53 = OpCompositeExtract %uchar %11 3 - %73 = OpBitcast %uchar %50 - %22 = OpUConvert %ushort %73 - %74 = OpBitcast %uchar %51 - %23 = OpUConvert %ushort %74 - %75 = OpBitcast %uchar %52 - %24 = OpUConvert %ushort %75 - %76 = OpBitcast %uchar %53 - %25 = OpUConvert %ushort %76 - OpStore %6 %22 - OpStore %7 %23 - OpStore %8 %24 - OpStore %9 %25 - %26 = OpLoad %ushort %7 - %27 = OpLoad %ushort %8 - %28 = OpLoad %ushort %9 - %29 = OpLoad %ushort %6 - %77 = OpUndef %v4ushort - %78 = OpCompositeInsert %v4ushort %26 %77 0 - %79 = OpCompositeInsert %v4ushort %27 %78 1 - %80 = OpCompositeInsert %v4ushort %28 %79 2 - %81 = OpCompositeInsert %v4ushort %29 %80 3 - %12 = OpCopyObject %v4ushort %81 - %30 = OpCopyObject %v4ushort %12 - OpStore %10 %30 - %31 = OpLoad %v4ushort %10 - %13 = OpCopyObject %v4ushort %31 - %32 = OpCompositeExtract %ushort %13 0 - %33 = OpCompositeExtract %ushort %13 1 - %34 = OpCompositeExtract %ushort %13 2 - %35 = OpCompositeExtract %ushort %13 3 - OpStore %8 %32 - OpStore %9 %33 - OpStore %6 %34 - OpStore %7 %35 - %36 = OpLoad %ushort %8 - %37 = OpLoad %ushort %9 - %38 = OpLoad %ushort %6 - %39 = OpLoad %ushort %7 - %82 = OpUndef %v4ushort - %83 = OpCompositeInsert %v4ushort %36 %82 0 - %84 = OpCompositeInsert %v4ushort %37 %83 1 - %85 = OpCompositeInsert %v4ushort %38 %84 2 - %86 = OpCompositeInsert %v4ushort %39 %85 3 - %15 = OpCopyObject %v4ushort %86 - %14 = OpCopyObject %v4ushort %15 - %40 = OpCompositeExtract %ushort %14 0 - %41 = OpCompositeExtract %ushort %14 1 - %42 = OpCompositeExtract %ushort %14 2 - %43 = OpCompositeExtract %ushort %14 3 - OpStore %9 %40 - OpStore %6 %41 - OpStore %7 %42 - OpStore %8 %43 - %44 = OpLoad %ushort %6 - %45 = OpLoad %ushort %7 - %46 = OpLoad %ushort %8 - %47 = OpLoad %ushort %9 - %87 = OpBitcast %ushort %44 - %54 = OpUConvert %uchar %87 - %88 = OpBitcast %ushort %45 - %55 = OpUConvert %uchar %88 - %89 = OpBitcast %ushort %46 - %56 = OpUConvert %uchar %89 - %90 = OpBitcast %ushort %47 - %57 = OpUConvert %uchar %90 - %91 = OpUndef %v4uchar - %92 = OpCompositeInsert %v4uchar %54 %91 0 - %93 = OpCompositeInsert %v4uchar %55 %92 1 - %94 = OpCompositeInsert %v4uchar %56 %93 2 - %95 = OpCompositeInsert %v4uchar %57 %94 3 - %16 = OpCopyObject %v4uchar %95 - %48 = OpLoad %ulong %5 - %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48 - OpStore %58 %16 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt deleted file mode 100644 index c3a1f6f..0000000 --- a/ptx/src/test/spirv_run/xor.spvtxt +++ /dev/null @@ -1,59 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %28 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "xor" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %31 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %uchar = OpTypeInt 8 0 -%_ptr_Generic_uchar = OpTypePointer Generic %uchar - %1 = OpFunction %void None %31 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %26 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %10 - %11 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %23 Aligned 4 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %15 - %38 = OpBitcast %_ptr_Generic_uchar %24 - %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 - %22 = OpBitcast %_ptr_Generic_uint %39 - %14 = OpLoad %uint %22 Aligned 4 - OpStore %7 %14 - %17 = OpLoad %uint %6 - %18 = OpLoad %uint %7 - %16 = OpBitwiseXor %uint %17 %18 - OpStore %6 %16 - %19 = OpLoad %ulong %5 - %20 = OpLoad %uint %6 - %25 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %25 %20 Aligned 4 - OpReturn - OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs deleted file mode 100644 index 9b422fd..0000000 --- a/ptx/src/translate.rs +++ /dev/null @@ -1,8181 +0,0 @@ -use crate::ast; -use half::f16; -use rspirv::dr; -use std::cell::RefCell; -use std::collections::{hash_map, BTreeMap, HashMap, HashSet}; -use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; - -use rspirv::binary::{Assemble, Disassemble}; - -static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); -static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; - -quick_error! { - #[derive(Debug)] - pub enum TranslateError { - UnknownSymbol {} - UntypedSymbol {} - MismatchedType {} - Spirv(err: rspirv::dr::Error) { - from() - display("{}", err) - cause(err) - } - Unreachable {} - Todo {} - } -} - -#[cfg(debug_assertions)] -fn error_unreachable() -> TranslateError { - unreachable!() -} - -#[cfg(not(debug_assertions))] -fn error_unreachable() -> TranslateError { - TranslateError::Unreachable -} - -fn error_unknown_symbol() -> TranslateError { - TranslateError::UnknownSymbol -} - -#[derive(PartialEq, Eq, Hash, Clone)] -enum SpirvType { - Base(SpirvScalarKey), - Vector(SpirvScalarKey, u8), - Array(SpirvScalarKey, Vec), - Pointer(Box, spirv::StorageClass), - Func(Option>, Vec), - Struct(Vec), -} - -impl SpirvType { - fn new(t: ast::Type) -> Self { - match t { - ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), - ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( - Box::new(SpirvType::Base(pointer_t.into())), - space.to_spirv(), - ), - } - } - - fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { - let key = Self::new(t); - SpirvType::Pointer(Box::new(key), outer_space) - } -} - -impl From for SpirvType { - fn from(t: ast::ScalarType) -> Self { - SpirvType::Base(t.into()) - } -} - -struct TypeWordMap { - void: spirv::Word, - complex: HashMap, - constants: HashMap<(SpirvType, u64), spirv::Word>, -} - -// SPIR-V integer type definitions are signless, more below: -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -enum SpirvScalarKey { - B8, - B16, - B32, - B64, - F16, - F32, - F64, - Pred, - F16x2, -} - -impl From for SpirvScalarKey { - fn from(t: ast::ScalarType) -> Self { - match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - SpirvScalarKey::B16 - } - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { - SpirvScalarKey::B32 - } - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { - SpirvScalarKey::B64 - } - ast::ScalarType::F16 => SpirvScalarKey::F16, - ast::ScalarType::F32 => SpirvScalarKey::F32, - ast::ScalarType::F64 => SpirvScalarKey::F64, - ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, - ast::ScalarType::Pred => SpirvScalarKey::Pred, - } - } -} - -impl TypeWordMap { - fn new(b: &mut dr::Builder) -> TypeWordMap { - let void = b.type_void(None); - TypeWordMap { - void: void, - complex: HashMap::::new(), - constants: HashMap::new(), - } - } - - fn void(&self) -> spirv::Word { - self.void - } - - fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word { - let key: SpirvScalarKey = t.into(); - self.get_or_add_spirv_scalar(b, key) - } - - fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word { - *self - .complex - .entry(SpirvType::Base(key)) - .or_insert_with(|| match key { - SpirvScalarKey::B8 => b.type_int(None, 8, 0), - SpirvScalarKey::B16 => b.type_int(None, 16, 0), - SpirvScalarKey::B32 => b.type_int(None, 32, 0), - SpirvScalarKey::B64 => b.type_int(None, 64, 0), - SpirvScalarKey::F16 => b.type_float(None, 16), - SpirvScalarKey::F32 => b.type_float(None, 32), - SpirvScalarKey::F64 => b.type_float(None, 64), - SpirvScalarKey::Pred => b.type_bool(None), - SpirvScalarKey::F16x2 => todo!(), - }) - } - - fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { - match t { - SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), - SpirvType::Pointer(ref typ, storage) => { - let base = self.get_or_add(b, *typ.clone()); - *self - .complex - .entry(t) - .or_insert_with(|| b.type_pointer(None, storage, base)) - } - SpirvType::Vector(typ, len) => { - let base = self.get_or_add_spirv_scalar(b, typ); - *self - .complex - .entry(t) - .or_insert_with(|| b.type_vector(None, base, len as u32)) - } - SpirvType::Array(typ, array_dimensions) => { - let (base_type, length) = match &*array_dimensions { - &[] => { - return self.get_or_add(b, SpirvType::Base(typ)); - } - &[len] => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self.get_or_add_spirv_scalar(b, typ); - let len_const = b.constant_u32(u32_type, None, len); - (base, len_const) - } - array_dimensions => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self - .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); - let len_const = b.constant_u32(u32_type, None, array_dimensions[0]); - (base, len_const) - } - }; - *self - .complex - .entry(SpirvType::Array(typ, array_dimensions)) - .or_insert_with(|| b.type_array(None, base_type, length)) - } - SpirvType::Func(ref out_params, ref in_params) => { - let out_t = match out_params { - Some(p) => self.get_or_add(b, *p.clone()), - None => self.void(), - }; - let in_t = in_params - .iter() - .map(|t| self.get_or_add(b, t.clone())) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| b.type_function(None, out_t, in_t)) - } - SpirvType::Struct(ref underlying) => { - let underlying_ids = underlying - .iter() - .map(|t| self.get_or_add_spirv_scalar(b, *t)) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| b.type_struct(None, underlying_ids)) - } - } - } - - fn get_or_add_fn( - &mut self, - b: &mut dr::Builder, - in_params: impl Iterator, - mut out_params: impl ExactSizeIterator, - ) -> (spirv::Word, spirv::Word) { - let (out_args, out_spirv_type) = if out_params.len() == 0 { - (None, self.void()) - } else if out_params.len() == 1 { - let arg_as_key = out_params.next().unwrap(); - ( - Some(Box::new(arg_as_key.clone())), - self.get_or_add(b, arg_as_key), - ) - } else { - // TODO: support multiple return values - todo!() - }; - ( - out_spirv_type, - self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), - ) - } - - fn get_or_add_constant( - &mut self, - b: &mut dr::Builder, - typ: &ast::Type, - init: &[u8], - ) -> Result { - Ok(match typ { - ast::Type::Scalar(t) => match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v), - ), - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v, - |b, result_type, v| b.constant_u64(result_type, None, v), - ), - ast::ScalarType::F16 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u16>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), - ), - ast::ScalarType::F32 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u32>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v), - ), - ast::ScalarType::F64 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u64>(v) }, - |b, result_type, v| b.constant_f64(result_type, None, v), - ), - ast::ScalarType::F16x2 => return Err(TranslateError::Todo), - ast::ScalarType::Pred => self.get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| { - if v == 0 { - b.constant_false(result_type, None) - } else { - b.constant_true(result_type, None) - } - }, - ), - }, - ast::Type::Vector(typ, len) => { - let result_type = - self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); - let size_of_t = typ.size_of(); - let components = (0..*len) - .map(|x| { - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - ) - }) - .collect::, _>>()?; - b.constant_composite(result_type, None, components.into_iter()) - } - ast::Type::Array(typ, dims) => match dims.as_slice() { - [] => return Err(error_unreachable()), - [dim] => { - let result_type = self - .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); - let size_of_t = typ.size_of(); - let components = (0..*dim) - .map(|x| { - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - ) - }) - .collect::, _>>()?; - b.constant_composite(result_type, None, components.into_iter()) - } - [first_dim, rest @ ..] => { - let result_type = self.get_or_add( - b, - SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), - ); - let size_of_t = rest - .iter() - .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); - let components = (0..*first_dim) - .map(|x| { - self.get_or_add_constant( - b, - &ast::Type::Array(*typ, rest.to_vec()), - &init[((size_of_t as usize) * (x as usize))..], - ) - }) - .collect::, _>>()?; - b.constant_composite(result_type, None, components.into_iter()) - } - }, - ast::Type::Pointer(..) => return Err(error_unreachable()), - }) - } - - fn get_or_add_constant_single< - T: Copy, - CastAsU64: FnOnce(T) -> u64, - InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, - >( - &mut self, - b: &mut dr::Builder, - key: ast::ScalarType, - init: &[u8], - cast: CastAsU64, - f: InsertConstant, - ) -> spirv::Word { - let value = unsafe { *(init.as_ptr() as *const T) }; - let value_64 = cast(value); - let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); - match self.constants.get(&ht_key) { - Some(value) => *value, - None => { - let spirv_type = self.get_or_add_scalar(b, key); - let result = f(b, spirv_type, value); - self.constants.insert(ht_key, result); - result - } - } - } -} - -pub struct Module { - pub spirv: dr::Module, - pub kernel_info: HashMap, - pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>, - pub build_options: CString, -} -impl Module { - pub fn assemble(&self) -> Vec { - self.spirv.assemble() - } -} - -pub struct KernelInfo { - pub arguments_sizes: Vec<(usize, bool)>, - pub uses_shared_mem: bool, -} - -pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result { - let mut id_defs = GlobalStringIdResolver::<'input>::new(1); - let mut ptx_impl_imports = HashMap::new(); - let directives = ast - .directives - .into_iter() - .filter_map(|directive| { - translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose() - }) - .collect::, _>>()?; - let directives = hoist_function_globals(directives); - let must_link_ptx_impl = ptx_impl_imports.len() > 0; - let mut directives = ptx_impl_imports - .into_iter() - .map(|(_, v)| v) - .chain(directives.into_iter()) - .collect::>(); - let mut builder = dr::Builder::new(); - builder.reserve_ids(id_defs.current_id()); - let call_map = MethodsCallMap::new(&directives); - let mut directives = - convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); - normalize_variable_decls(&mut directives); - let denorm_information = compute_denorm_information(&directives); - // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module - builder.set_version(1, 3); - emit_capabilities(&mut builder); - emit_extensions(&mut builder); - let opencl_id = emit_opencl_import(&mut builder); - emit_memory_model(&mut builder); - let mut map = TypeWordMap::new(&mut builder); - //emit_builtins(&mut builder, &mut map, &id_defs); - let mut kernel_info = HashMap::new(); - let (build_options, should_flush_denorms) = - emit_denorm_build_string(&call_map, &denorm_information); - let (directives, globals_use_map) = get_globals_use_map(directives); - emit_directives( - &mut builder, - &mut map, - &id_defs, - opencl_id, - should_flush_denorms, - &call_map, - globals_use_map, - directives, - &mut kernel_info, - )?; - let spirv = builder.module(); - Ok(Module { - spirv, - kernel_info, - should_link_ptx_impl: if must_link_ptx_impl { - Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD)) - } else { - None - }, - build_options, - }) -} - -fn get_globals_use_map<'input>( - directives: Vec>, -) -> ( - Vec>, - HashMap, HashSet>, -) { - let mut known_globals = HashSet::new(); - for directive in directives.iter() { - match directive { - Directive::Variable(_, ast::Variable { name, .. }) => { - known_globals.insert(*name); - } - Directive::Method(..) => {} - } - } - let mut symbol_uses_map = HashMap::new(); - let directives = directives - .into_iter() - .map(|directive| match directive { - Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive, - Directive::Method(Function { - func_decl, - body: Some(mut statements), - globals, - import_as, - tuning, - linkage, - }) => { - let method_name = func_decl.borrow().name; - statements = statements - .into_iter() - .map(|statement| { - statement.map_id(&mut |symbol, _| { - if known_globals.contains(&symbol) { - multi_hash_map_append(&mut symbol_uses_map, method_name, symbol); - } - symbol - }) - }) - .collect::>(); - Directive::Method(Function { - func_decl, - body: Some(statements), - globals, - import_as, - tuning, - linkage, - }) - } - }) - .collect::>(); - (directives, symbol_uses_map) -} - -fn hoist_function_globals(directives: Vec) -> Vec { - let mut result = Vec::with_capacity(directives.len()); - for directive in directives { - match directive { - Directive::Method(method) => { - for variable in method.globals { - result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); - } - result.push(Directive::Method(Function { - globals: Vec::new(), - ..method - })) - } - _ => result.push(directive), - } - } - result -} - -// TODO: remove this once we have pef-function support for denorms -fn emit_denorm_build_string<'input>( - call_map: &MethodsCallMap, - denorm_information: &HashMap< - ast::MethodName<'input, spirv::Word>, - HashMap, - >, -) -> (CString, bool) { - let denorm_counts = denorm_information - .iter() - .map(|(method, meth_denorm)| { - let f16_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - let f32_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - (method, (f16_count + f32_count)) - }) - .collect::>(); - let mut flush_over_preserve = 0; - for (kernel, children) in call_map.kernels() { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Kernel(kernel)) - .unwrap_or(&0); - for child_fn in children { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Func(*child_fn)) - .unwrap_or(&0); - } - } - if flush_over_preserve > 0 { - ( - CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), - true, - ) - } else { - (CString::new("-ze-take-global-address").unwrap(), false) - } -} - -fn emit_directives<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl_id: spirv::Word, - should_flush_denorms: bool, - call_map: &MethodsCallMap<'input>, - globals_use_map: HashMap, HashSet>, - directives: Vec>, - kernel_info: &mut HashMap, -) -> Result<(), TranslateError> { - let empty_body = Vec::new(); - for d in directives.iter() { - match d { - Directive::Variable(linking, var) => { - emit_variable(builder, map, id_defs, *linking, &var)?; - } - Directive::Method(f) => { - let f_body = match &f.body { - Some(f) => f, - None => { - if f.linkage.contains(ast::LinkingDirective::EXTERN) { - &empty_body - } else { - continue; - } - } - }; - for var in f.globals.iter() { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - let func_decl = (*f.func_decl).borrow(); - let fn_id = emit_function_header( - builder, - map, - &id_defs, - &*func_decl, - call_map, - &globals_use_map, - kernel_info, - )?; - if func_decl.name.is_kernel() { - if should_flush_denorms { - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::DenormFlushToZero, - [16], - ); - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::DenormFlushToZero, - [32], - ); - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::DenormFlushToZero, - [64], - ); - } - // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) - builder.execution_mode(fn_id, spirv_headers::ExecutionMode::ContractionOff, []); - for t in f.tuning.iter() { - match *t { - ast::TuningDirective::MaxNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, - [nx, ny, nz], - ); - } - ast::TuningDirective::ReqNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id, - spirv_headers::ExecutionMode::LocalSize, - [nx, ny, nz], - ); - } - // Too architecture specific - ast::TuningDirective::MaxNReg(..) - | ast::TuningDirective::MinNCtaPerSm(..) => {} - } - } - } - emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; - emit_function_linkage(builder, id_defs, f, fn_id)?; - builder.select_block(None)?; - builder.end_function()?; - } - } - } - Ok(()) -} - -fn emit_function_linkage<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - f: &Function, - fn_name: spirv::Word, -) -> Result<(), TranslateError> { - if f.linkage == ast::LinkingDirective::NONE { - return Ok(()); - }; - let linking_name = match f.func_decl.borrow().name { - // According to SPIR-V rules linkage attributes are invalid on kernels - ast::MethodName::Kernel(..) => return Ok(()), - ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( - || match id_defs.reverse_variables.get(&fn_id) { - Some(fn_name) => Ok(fn_name), - None => Err(error_unknown_symbol()), - }, - Result::Ok, - )?, - }; - emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); - Ok(()) -} - -struct MethodsCallMap<'input> { - map: HashMap, HashSet>, -} - -impl<'input> MethodsCallMap<'input> { - fn new(module: &[Directive<'input>]) -> Self { - let mut directly_called_by = HashMap::new(); - for directive in module { - match directive { - Directive::Method(Function { - func_decl, - body: Some(statements), - .. - }) => { - let call_key: ast::MethodName<_> = (**func_decl).borrow().name; - if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { - entry.insert(Vec::new()); - } - for statement in statements { - match statement { - Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call_key, call.name); - } - _ => {} - } - } - } - _ => {} - } - } - let mut result = HashMap::new(); - for (&method_key, children) in directly_called_by.iter() { - let mut visited = HashSet::new(); - for child in children { - Self::add_call_map_single(&directly_called_by, &mut visited, *child); - } - result.insert(method_key, visited); - } - MethodsCallMap { map: result } - } - - fn add_call_map_single( - directly_called_by: &HashMap, Vec>, - visited: &mut HashSet, - current: spirv::Word, - ) { - if !visited.insert(current) { - return; - } - if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { - for child in children { - Self::add_call_map_single(directly_called_by, visited, *child); - } - } - } - - fn get_kernel_children(&self, name: &'input str) -> impl Iterator { - self.map - .get(&ast::MethodName::Kernel(name)) - .into_iter() - .flatten() - } - - fn kernels(&self) -> impl Iterator)> { - self.map - .iter() - .filter_map(|(method, children)| match method { - ast::MethodName::Kernel(kernel) => Some((*kernel, children)), - ast::MethodName::Func(..) => None, - }) - } - - fn methods( - &self, - ) -> impl Iterator, &HashSet)> { - self.map - .iter() - .map(|(method, children)| (*method, children)) - } - - fn visit_callees( - &self, - method: ast::MethodName<'input, spirv::Word>, - f: impl FnMut(spirv::Word), - ) { - self.map - .get(&method) - .into_iter() - .flatten() - .copied() - .for_each(f); - } -} - -fn multi_hash_map_append< - K: Eq + std::hash::Hash, - V, - Collection: std::iter::Extend + std::default::Default, ->( - m: &mut HashMap, - key: K, - value: V, -) { - match m.entry(key) { - hash_map::Entry::Occupied(mut entry) => { - entry.get_mut().extend(iter::once(value)); - } - hash_map::Entry::Vacant(entry) => { - entry.insert(Default::default()).extend(iter::once(value)); - } - } -} - -/* - PTX represents dynamically allocated shared local memory as - .extern .shared .b32 shared_mem[]; - In SPIRV/OpenCL world this is expressed as an additional argument to the kernel - And in AMD compilation - This pass looks for all uses of .extern .shared and converts them to - an additional method argument - The question is how this artificial argument should be expressed. There are - several options: - * Straight conversion: - .shared .b32 shared_mem[] - * Introduce .param_shared statespace: - .param_shared .b32 shared_mem - or - .param_shared .b32 shared_mem[] - * Introduce .shared_ptr type: - .param .shared_ptr .b32 shared_mem - * Reuse .ptr hint: - .param .u64 .ptr shared_mem - This is the most tempting, but also the most nonsensical, .ptr is just a - hint, which has no semantical meaning (and the output of our - transformation has a semantical meaning - we emit additional - "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") -*/ -fn convert_dynamic_shared_memory_usage<'input>( - module: Vec>, - kernels_methods_call_map: &MethodsCallMap<'input>, - new_id: &mut impl FnMut() -> spirv::Word, -) -> Vec> { - let mut globals_shared = HashMap::new(); - for dir in module.iter() { - match dir { - Directive::Variable( - _, - ast::Variable { - state_space: ast::StateSpace::Shared, - name, - v_type, - .. - }, - ) => { - globals_shared.insert(*name, v_type.clone()); - } - _ => {} - } - } - if globals_shared.len() == 0 { - return module; - } - let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); - let module = module - .into_iter() - .map(|directive| match directive { - Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - }) => { - let call_key = (*func_decl).borrow().name; - let statements = statements - .into_iter() - .map(|statement| { - statement.map_id(&mut |id, _| { - if let Some(type_) = globals_shared.get(&id) { - methods_to_directly_used_shared_globals - .entry(call_key) - .or_insert_with(HashSet::new) - .insert(id); - } - id - }) - }) - .collect(); - Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - }) - } - directive => directive, - }) - .collect::>(); - // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, - // make sure it gets propagated to `fn1` and `kernel` - let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( - methods_to_directly_used_shared_globals, - kernels_methods_call_map, - ); - // now visit every method declaration and inject those additional arguments - let mut directives = Vec::with_capacity(module.len()); - for directive in module.into_iter() { - match directive { - Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - }) => { - let statements = { - let func_decl_ref = &mut (*func_decl).borrow_mut(); - let method_name = func_decl_ref.name; - insert_arguments_remap_statements( - new_id, - kernels_methods_call_map, - &globals_shared, - &methods_to_indirectly_used_shared_globals, - method_name, - &mut directives, - func_decl_ref, - statements, - ) - }; - directives.push(Directive::Method(Function { - func_decl, - globals, - body: Some(statements), - import_as, - tuning, - linkage, - })); - } - directive => directives.push(directive), - } - } - directives -} - -fn insert_arguments_remap_statements<'input>( - new_id: &mut impl FnMut() -> u32, - kernels_methods_call_map: &MethodsCallMap<'input>, - globals_shared: &HashMap, - methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, spirv::Word>, - BTreeSet, - >, - method_name: ast::MethodName, - result: &mut Vec, - func_decl_ref: &mut std::cell::RefMut>, - statements: Vec, ExpandedArgParams>>, -) -> Vec, ExpandedArgParams>> { - let remapped_globals_in_method = - if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { - match method_name { - ast::MethodName::Func(..) => { - let remapped_globals = method_globals - .iter() - .map(|global| { - ( - *global, - ( - new_id(), - globals_shared - .get(&global) - .unwrap_or_else(|| todo!()) - .clone(), - ), - ) - }) - .collect::>(); - for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { - func_decl_ref.input_arguments.push(ast::Variable { - align: None, - v_type: shared_global_type.clone(), - state_space: ast::StateSpace::Shared, - name: *new_shared_global_id, - array_init: Vec::new(), - }); - } - remapped_globals - } - ast::MethodName::Kernel(..) => method_globals - .iter() - .map(|global| { - ( - *global, - ( - *global, - globals_shared - .get(&global) - .unwrap_or_else(|| todo!()) - .clone(), - ), - ) - }) - .collect::>(), - } - } else { - return statements; - }; - replace_uses_of_shared_memory( - new_id, - methods_to_indirectly_used_shared_globals, - statements, - remapped_globals_in_method, - ) -} - -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -enum GlobalSharedSize { - ExternUnsized, - Sized(usize), -} - -impl GlobalSharedSize { - fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize { - match (self, other) { - (GlobalSharedSize::Sized(s1), GlobalSharedSize::Sized(s2)) => { - GlobalSharedSize::Sized(usize::max(s1, s2)) - } - _ => GlobalSharedSize::ExternUnsized, - } - } -} - -fn replace_uses_of_shared_memory<'input>( - new_id: &mut impl FnMut() -> spirv::Word, - methods_to_indirectly_used_shared_globals: &HashMap< - ast::MethodName<'input, spirv::Word>, - BTreeSet, - >, - statements: Vec, - remapped_globals_in_method: BTreeMap, -) -> Vec { - let mut result = Vec::with_capacity(statements.len()); - for statement in statements { - match statement { - Statement::Call(mut call) => { - // We can safely skip checking call arguments, - // because there's simply no way to pass shared ptr - // without converting it to .b64 first - if let Some(shared_globals_used_by_callee) = - methods_to_indirectly_used_shared_globals.get(&ast::MethodName::Func(call.name)) - { - for &shared_global_used_by_callee in shared_globals_used_by_callee { - let (remapped_shared_id, type_) = remapped_globals_in_method - .get(&shared_global_used_by_callee) - .unwrap_or_else(|| todo!()); - call.input_arguments.push(( - *remapped_shared_id, - type_.clone(), - ast::StateSpace::Shared, - )); - } - } - result.push(Statement::Call(call)) - } - statement => { - let new_statement = statement.map_id(&mut |id, _| { - if let Some((remapped_shared_id, _)) = remapped_globals_in_method.get(&id) { - *remapped_shared_id - } else { - id - } - }); - result.push(new_statement); - } - } - } - result -} - -// We need to compute two kinds of information: -// * If it's a kernel -> size of .shared globals in use (direct or indirect) -// * If it's a function -> does it use .shared global (directly or indirectly) -fn resolve_indirect_uses_of_globals_shared<'input>( - methods_use_of_globals_shared: HashMap< - ast::MethodName<'input, spirv::Word>, - HashSet, - >, - kernels_methods_call_map: &MethodsCallMap<'input>, -) -> HashMap, BTreeSet> { - let mut result = HashMap::new(); - for (method, callees) in kernels_methods_call_map.methods() { - let mut indirect_globals = methods_use_of_globals_shared - .get(&method) - .into_iter() - .flatten() - .copied() - .collect::>(); - for &callee in callees { - indirect_globals.extend( - methods_use_of_globals_shared - .get(&ast::MethodName::Func(callee)) - .into_iter() - .flatten() - .copied(), - ); - } - result.insert(method, indirect_globals); - } - result -} - -type DenormCountMap = HashMap; - -fn denorm_count_map_update(map: &mut DenormCountMap, key: T, value: bool) { - let num_value = if value { 1 } else { -1 }; - denorm_count_map_update_impl(map, key, num_value); -} - -fn denorm_count_map_update_impl( - map: &mut DenormCountMap, - key: T, - num_value: isize, -) { - match map.entry(key) { - hash_map::Entry::Occupied(mut counter) => { - *(counter.get_mut()) += num_value; - } - hash_map::Entry::Vacant(entry) => { - entry.insert(num_value); - } - } -} - -// HACK ALERT! -// This function is a "good enough" heuristic of whetever to mark f16/f32 operations -// in the kernel as flushing denorms to zero or preserving them -// PTX support per-instruction ftz information. Unfortunately SPIR-V has no -// such capability, so instead we guesstimate which use is more common in the kernel -// and emit suitable execution mode -fn compute_denorm_information<'input>( - module: &[Directive<'input>], -) -> HashMap, HashMap> { - let mut denorm_methods = HashMap::new(); - for directive in module { - match directive { - Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} - Directive::Method(Function { - func_decl, - body: Some(statements), - .. - }) => { - let mut flush_counter = DenormCountMap::new(); - let method_key = (**func_decl).borrow().name; - for statement in statements { - match statement { - Statement::Instruction(inst) => { - if let Some((flush, width)) = inst.flush_to_zero() { - denorm_count_map_update(&mut flush_counter, width, flush); - } - } - Statement::LoadVar(..) => {} - Statement::StoreVar(..) => {} - Statement::Call(_) => {} - Statement::Conditional(_) => {} - Statement::Conversion(_) => {} - Statement::Constant(_) => {} - Statement::RetValue(_, _) => {} - Statement::Label(_) => {} - Statement::Variable(_) => {} - Statement::PtrAccess { .. } => {} - Statement::RepackVector(_) => {} - Statement::FunctionPointer(_) => {} - } - } - denorm_methods.insert(method_key, flush_counter); - } - } - } - denorm_methods - .into_iter() - .map(|(name, v)| { - let width_to_denorm = v - .into_iter() - .map(|(k, flush_over_preserve)| { - let mode = if flush_over_preserve > 0 { - spirv::FPDenormMode::FlushToZero - } else { - spirv::FPDenormMode::Preserve - }; - (k, (mode, flush_over_preserve)) - }) - .collect(); - (name, width_to_denorm) - }) - .collect() -} - -fn emit_function_header<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - defined_globals: &GlobalStringIdResolver<'input>, - func_decl: &ast::MethodDeclaration<'input, spirv::Word>, - call_map: &MethodsCallMap<'input>, - globals_use_map: &HashMap, HashSet>, - kernel_info: &mut HashMap, -) -> Result { - if let ast::MethodName::Kernel(name) = func_decl.name { - let args_lens = func_decl - .input_arguments - .iter() - .map(|param| { - ( - param.v_type.size_of(), - matches!(param.v_type, ast::Type::Pointer(..)), - ) - }) - .collect(); - kernel_info.insert( - name.to_string(), - KernelInfo { - arguments_sizes: args_lens, - uses_shared_mem: func_decl.shared_mem.is_some(), - }, - ); - } - let (ret_type, func_type) = get_function_type( - builder, - map, - func_decl.effective_input_arguments().map(|(_, typ)| typ), - &func_decl.return_arguments, - ); - let fn_id = match func_decl.name { - ast::MethodName::Kernel(name) => { - let fn_id = defined_globals.get_id(name)?; - let interface = globals_use_map - .get(&ast::MethodName::Kernel(name)) - .into_iter() - .flatten() - .copied() - .chain({ - call_map - .get_kernel_children(name) - .copied() - .flat_map(|subfunction| { - globals_use_map - .get(&ast::MethodName::Func(subfunction)) - .into_iter() - .flatten() - .copied() - }) - .into_iter() - }) - .collect::>(); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface); - fn_id - } - ast::MethodName::Func(name) => name, - }; - builder.begin_function( - ret_type, - Some(fn_id), - spirv::FunctionControl::NONE, - func_type, - )?; - for (name, typ) in func_decl.effective_input_arguments() { - let result_type = map.get_or_add(builder, typ); - builder.function_parameter(Some(name), result_type)?; - } - Ok(fn_id) -} - -fn emit_capabilities(builder: &mut dr::Builder) { - builder.capability(spirv::Capability::GenericPointer); - builder.capability(spirv::Capability::Linkage); - builder.capability(spirv::Capability::Addresses); - builder.capability(spirv::Capability::Kernel); - builder.capability(spirv::Capability::Int8); - builder.capability(spirv::Capability::Int16); - builder.capability(spirv::Capability::Int64); - builder.capability(spirv::Capability::Float16); - builder.capability(spirv::Capability::Float64); - builder.capability(spirv::Capability::DenormFlushToZero); - // TODO: re-enable when Intel float control extension works - //builder.capability(spirv::Capability::FunctionFloatControlINTEL); -} - -// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html -fn emit_extensions(builder: &mut dr::Builder) { - // TODO: re-enable when Intel float control extension works - //builder.extension("SPV_INTEL_float_controls2"); - builder.extension("SPV_KHR_float_controls"); - builder.extension("SPV_KHR_no_integer_wrap_decoration"); -} - -fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { - builder.ext_inst_import("OpenCL.std") -} - -fn emit_memory_model(builder: &mut dr::Builder) { - builder.memory_model( - spirv::AddressingModel::Physical64, - spirv::MemoryModel::OpenCL, - ); -} - -fn translate_directive<'input, 'a>( - id_defs: &'a mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &'a mut HashMap>, - d: ast::Directive<'input, ast::ParsedArgParams<'input>>, -) -> Result>, TranslateError> { - Ok(match d { - ast::Directive::Variable(linking, var) => Some(Directive::Variable( - linking, - ast::Variable { - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), - array_init: var.array_init, - }, - )), - ast::Directive::Method(linkage, f) => { - translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method) - } - }) -} - -fn translate_function<'input, 'a>( - id_defs: &'a mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &'a mut HashMap>, - linkage: ast::LinkingDirective, - f: ast::ParsedFunction<'input>, -) -> Result>, TranslateError> { - let import_as = match &f.func_directive { - ast::MethodDeclaration { - name: ast::MethodName::Func(func_name), - .. - } if *func_name == "__assertfail" || *func_name == "vprintf" => { - Some([ZLUDA_PTX_PREFIX, func_name].concat()) - } - _ => None, - }; - let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; - let mut func = to_ssa( - ptx_impl_imports, - str_resolver, - fn_resolver, - fn_decl, - f.body, - f.tuning, - linkage, - )?; - func.import_as = import_as; - if func.import_as.is_some() { - ptx_impl_imports.insert( - func.import_as.as_ref().unwrap().clone(), - Directive::Method(func), - ); - Ok(None) - } else { - Ok(Some(func)) - } -} - -fn rename_fn_params<'a, 'b>( - fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: &'b [ast::Variable<&'a str>], -) -> Vec> { - args.iter() - .map(|a| ast::Variable { - name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), - v_type: a.v_type.clone(), - state_space: a.state_space, - align: a.align, - array_init: a.array_init.clone(), - }) - .collect() -} - -fn to_ssa<'input, 'b>( - ptx_impl_imports: &'b mut HashMap>, - mut id_defs: FnStringIdResolver<'input, 'b>, - fn_defs: GlobalFnDeclResolver<'input, 'b>, - func_decl: Rc>>, - f_body: Option>>>, - tuning: Vec, - linkage: ast::LinkingDirective, -) -> Result, TranslateError> { - //deparamize_function_decl(&func_decl)?; - let f_body = match f_body { - Some(vec) => vec, - None => { - return Ok(Function { - func_decl: func_decl, - body: None, - globals: Vec::new(), - import_as: None, - tuning, - linkage, - }) - } - }; - let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; - let mut numeric_id_defs = id_defs.finish(); - let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; - let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; - let (func_decl, typed_statements) = - convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = insert_mem_ssa_statements( - typed_statements, - &mut numeric_id_defs, - &mut (*func_decl).borrow_mut(), - )?; - let mut numeric_id_defs = numeric_id_defs.finish(); - let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; - let expanded_statements = - insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; - let mut numeric_id_defs = numeric_id_defs.unmut(); - let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - let (f_body, globals) = - extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; - Ok(Function { - func_decl: func_decl, - globals: globals, - body: Some(f_body), - import_as: None, - tuning, - linkage, - }) -} - -fn fix_special_registers2<'a, 'b, 'input>( - ptx_impl_imports: &'a mut HashMap>, - typed_statements: Vec, - numeric_id_defs: &'a mut NumericIdResolver<'b>, -) -> Result, TranslateError> { - let result = Vec::with_capacity(typed_statements.len()); - let mut sreg_sresolver = SpecialRegisterResolver { - ptx_impl_imports, - numeric_id_defs, - result, - }; - for s in typed_statements { - match s { - Statement::Call(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - Statement::Instruction(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - Statement::Conditional(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - Statement::Conversion(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - Statement::PtrAccess(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - Statement::RepackVector(details) => { - let new_statement = details.visit(&mut sreg_sresolver)?; - sreg_sresolver.result.push(new_statement); - } - s @ Statement::Variable(_) - | s @ Statement::Label(_) - | s @ Statement::FunctionPointer(_) => sreg_sresolver.result.push(s), - _ => return Err(error_unreachable()), - } - } - Ok(sreg_sresolver.result) -} - -struct SpecialRegisterResolver<'a, 'b, 'input> { - ptx_impl_imports: &'a mut HashMap>, - numeric_id_defs: &'a mut NumericIdResolver<'b>, - result: Vec, -} - -impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { - fn replace_sreg( - &mut self, - desc: ArgumentDescriptor, - vector_index: Option, - ) -> Result { - if let Some(sreg) = self.numeric_id_defs.special_registers.get(desc.op) { - if desc.is_dst { - return Err(TranslateError::MismatchedType); - } - let input_arguments = match (vector_index, sreg.get_function_input_type()) { - (Some(idx), Some(inp_type)) => { - if inp_type != ast::ScalarType::U8 { - return Err(TranslateError::Unreachable); - } - let constant = self.numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(inp_type), - ast::StateSpace::Reg, - ))); - self.result.push(Statement::Constant(ConstantDefinition { - dst: constant, - typ: inp_type, - value: ast::ImmediateValue::U64(idx as u64), - })); - vec![( - TypedOperand::Reg(constant), - ast::Type::Scalar(inp_type), - ast::StateSpace::Reg, - )] - } - (None, None) => Vec::new(), - _ => return Err(TranslateError::MismatchedType), - }; - let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); - let return_type = sreg.get_function_return_type(); - let fn_result = self.numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(return_type), - ast::StateSpace::Reg, - ))); - let return_arguments = vec![( - fn_result, - ast::Type::Scalar(return_type), - ast::StateSpace::Reg, - )]; - let fn_call = register_external_fn_call( - self.numeric_id_defs, - self.ptx_impl_imports, - ocl_fn_name.to_string(), - return_arguments.iter().map(|(_, typ, space)| (typ, *space)), - input_arguments.iter().map(|(_, typ, space)| (typ, *space)), - )?; - self.result.push(Statement::Call(ResolvedCall { - uniform: false, - return_arguments, - name: fn_call, - input_arguments, - })); - Ok(fn_result) - } else { - Ok(desc.op) - } - } -} - -impl<'a, 'b, 'input> ArgumentMapVisitor - for SpecialRegisterResolver<'a, 'b, 'input> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.replace_sreg(desc, None) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(reg) => TypedOperand::Reg(self.replace_sreg(desc.new_op(reg), None)?), - op @ TypedOperand::RegOffset(_, _) => op, - op @ TypedOperand::Imm(_) => op, - TypedOperand::VecMember(reg, idx) => { - TypedOperand::VecMember(self.replace_sreg(desc.new_op(reg), Some(idx))?, idx) - } - }) - } -} - -fn extract_globals<'input, 'b>( - sorted_statements: Vec, - ptx_impl_imports: &mut HashMap, - id_def: &mut NumericIdResolver, -) -> Result<(Vec, Vec>), TranslateError> { - let mut local = Vec::with_capacity(sorted_statements.len()); - let mut global = Vec::new(); - for statement in sorted_statements { - match statement { - Statement::Variable( - var @ ast::Variable { - state_space: ast::StateSpace::Shared, - .. - }, - ) - | Statement::Variable( - var @ ast::Variable { - state_space: ast::StateSpace::Global, - .. - }, - ) => global.push(var), - Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Bfe { typ, arg }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Bfi { typ, arg }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Brev { typ, arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "brev_", typ.to_ptx_name()].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Brev { typ, arg }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Activemask { arg }) => { - let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Activemask { arg }, - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom( - details @ ast::AtomDetails { - inner: - ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Inc, - .. - }, - .. - }, - args, - )) => { - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - details.semantics.to_ptx_name(), - "_", - details.scope.to_ptx_name(), - "_", - details.space.to_ptx_name(), - "_inc", - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom(details, args), - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom( - details @ ast::AtomDetails { - inner: - ast::AtomInnerDetails::Unsigned { - op: ast::AtomUIntOp::Dec, - .. - }, - .. - }, - args, - )) => { - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - details.semantics.to_ptx_name(), - "_", - details.scope.to_ptx_name(), - "_", - details.space.to_ptx_name(), - "_dec", - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom(details, args), - fn_name, - )?); - } - Statement::Instruction(ast::Instruction::Atom( - details @ ast::AtomDetails { - inner: - ast::AtomInnerDetails::Float { - op: ast::AtomFloatOp::Add, - .. - }, - .. - }, - args, - )) => { - let fn_name = [ - ZLUDA_PTX_PREFIX, - "atom_", - details.semantics.to_ptx_name(), - "_", - details.scope.to_ptx_name(), - "_", - details.space.to_ptx_name(), - "_add_", - details.inner.get_type().to_ptx_name(), - ] - .concat(); - local.push(instruction_to_fn_call( - id_def, - ptx_impl_imports, - ast::Instruction::Atom(details, args), - fn_name, - )?); - } - s => local.push(s), - } - } - Ok((local, global)) -} - -impl ast::ScalarType { - fn to_ptx_name(self) -> &'static str { - match self { - ast::ScalarType::B8 => "b8", - ast::ScalarType::B16 => "b16", - ast::ScalarType::B32 => "b32", - ast::ScalarType::B64 => "b64", - ast::ScalarType::U8 => "u8", - ast::ScalarType::U16 => "u16", - ast::ScalarType::U32 => "u32", - ast::ScalarType::U64 => "u64", - ast::ScalarType::S8 => "s8", - ast::ScalarType::S16 => "s16", - ast::ScalarType::S32 => "s32", - ast::ScalarType::S64 => "s64", - ast::ScalarType::F16 => "f16", - ast::ScalarType::F32 => "f32", - ast::ScalarType::F64 => "f64", - ast::ScalarType::F16x2 => "f16x2", - ast::ScalarType::Pred => "pred", - } - } -} - -impl ast::AtomSemantics { - fn to_ptx_name(self) -> &'static str { - match self { - ast::AtomSemantics::Relaxed => "relaxed", - ast::AtomSemantics::Acquire => "acquire", - ast::AtomSemantics::Release => "release", - ast::AtomSemantics::AcquireRelease => "acq_rel", - } - } -} - -impl ast::MemScope { - fn to_ptx_name(self) -> &'static str { - match self { - ast::MemScope::Cta => "cta", - ast::MemScope::Gpu => "gpu", - ast::MemScope::Sys => "sys", - } - } -} - -impl ast::StateSpace { - fn to_ptx_name(self) -> &'static str { - match self { - ast::StateSpace::Generic => "generic", - ast::StateSpace::Global => "global", - ast::StateSpace::Shared => "shared", - ast::StateSpace::Reg => "reg", - ast::StateSpace::Const => "const", - ast::StateSpace::Local => "local", - ast::StateSpace::Param => "param", - ast::StateSpace::Sreg => "sreg", - } - } -} - -fn normalize_variable_decls(directives: &mut Vec) { - for directive in directives { - match directive { - Directive::Method(Function { - body: Some(func), .. - }) => { - func[1..].sort_by_key(|s| match s { - Statement::Variable(_) => 0, - _ => 1, - }); - } - _ => (), - } - } -} - -fn convert_to_typed_statements( - func: Vec, - fn_defs: &GlobalFnDeclResolver, - id_defs: &mut NumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::::with_capacity(func.len()); - for s in func { - match s { - Statement::Instruction(inst) => match inst { - ast::Instruction::Mov( - mov, - ast::Arg2Mov { - dst: ast::Operand::Reg(dst_reg), - src: ast::Operand::Reg(src_reg), - }, - ) if fn_defs.fns.contains_key(&src_reg) => { - if mov.typ != ast::Type::Scalar(ast::ScalarType::U64) { - return Err(TranslateError::MismatchedType); - } - result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { - dst: dst_reg, - src: src_reg, - })); - } - ast::Instruction::Call(call) => { - let resolver = fn_defs.get_fn_sig_resolver(call.func)?; - let resolved_call = resolver.resolve_in_spirv_repr(call)?; - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let reresolved_call = resolved_call.visit(&mut visitor)?; - visitor.func.push(reresolved_call); - visitor.func.extend(visitor.post_stmts); - } - inst => { - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction(inst.map(&mut visitor)?); - visitor.func.push(instruction); - visitor.func.extend(visitor.post_stmts); - } - }, - Statement::Label(i) => result.push(Statement::Label(i)), - Statement::Variable(v) => result.push(Statement::Variable(v)), - Statement::Conditional(c) => result.push(Statement::Conditional(c)), - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - -struct VectorRepackVisitor<'a, 'b> { - func: &'b mut Vec, - id_def: &'b mut NumericIdResolver<'a>, - post_stmts: Option, -} - -impl<'a, 'b> VectorRepackVisitor<'a, 'b> { - fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { - VectorRepackVisitor { - func, - id_def, - post_stmts: None, - } - } - - fn convert_vector( - &mut self, - is_dst: bool, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, - typ: &ast::Type, - state_space: ast::StateSpace, - idx: Vec, - ) -> Result { - // mov.u32 foobar, {a,b}; - let scalar_t = match typ { - ast::Type::Vector(scalar_t, _) => *scalar_t, - _ => return Err(TranslateError::MismatchedType), - }; - let temp_vec = self - .id_def - .register_intermediate(Some((typ.clone(), state_space))); - let statement = Statement::RepackVector(RepackVectorDetails { - is_extract: is_dst, - typ: scalar_t, - packed: temp_vec, - unpacked: idx, - non_default_implicit_conversion, - }); - if is_dst { - self.post_stmts = Some(statement); - } else { - self.func.push(statement); - } - Ok(temp_vec) - } -} - -impl<'a, 'b> ArgumentMapVisitor - for VectorRepackVisitor<'a, 'b> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - Ok(desc.op) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - ast::Operand::Reg(reg) => TypedOperand::Reg(reg), - ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), - ast::Operand::Imm(x) => TypedOperand::Imm(x), - ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( - desc.is_dst, - desc.non_default_implicit_conversion, - typ, - state_space, - vec, - )?), - }) - } -} - -fn instruction_to_fn_call( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - inst: ast::Instruction, - fn_name: String, -) -> Result { - let mut arguments = Vec::new(); - inst.visit(&mut |desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>| { - let (typ, space) = match typ { - Some((typ, space)) => (typ.clone(), space), - None => return Err(error_unreachable()), - }; - arguments.push((desc, typ, space)); - Ok(0) - })?; - let return_arguments_count = arguments - .iter() - .position(|(desc, _, _)| !desc.is_dst) - .unwrap_or(arguments.len()); - let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); - let fn_id = register_external_fn_call( - id_defs, - ptx_impl_imports, - fn_name, - return_arguments.iter().map(|(_, typ, state)| (typ, *state)), - input_arguments.iter().map(|(_, typ, state)| (typ, *state)), - )?; - Ok(Statement::Call(ResolvedCall { - uniform: false, - name: fn_id, - return_arguments: arguments_to_resolved_arguments(return_arguments), - input_arguments: arguments_to_resolved_arguments(input_arguments), - })) -} - -fn register_external_fn_call<'a>( - id_defs: &mut NumericIdResolver, - ptx_impl_imports: &mut HashMap, - name: String, - return_arguments: impl Iterator, - input_arguments: impl Iterator, -) -> Result { - match ptx_impl_imports.entry(name) { - hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.register_intermediate(None); - let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); - let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); - let func_decl = ast::MethodDeclaration:: { - return_arguments, - name: ast::MethodName::Func(fn_id), - input_arguments, - shared_mem: None, - }; - let func = Function { - func_decl: Rc::new(RefCell::new(func_decl)), - globals: Vec::new(), - body: None, - import_as: Some(entry.key().clone()), - tuning: Vec::new(), - linkage: ast::LinkingDirective::EXTERN, - }; - entry.insert(Directive::Method(func)); - Ok(fn_id) - } - hash_map::Entry::Occupied(entry) => match entry.get() { - Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name { - ast::MethodName::Func(fn_id) => Ok(fn_id), - ast::MethodName::Kernel(_) => Err(error_unreachable()), - }, - _ => Err(error_unreachable()), - }, - } -} - -fn fn_arguments_to_variables<'a>( - id_defs: &mut NumericIdResolver, - args: impl Iterator, -) -> Vec> { - args.map(|(typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() -} - -fn arguments_to_resolved_arguments( - args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], -) -> Vec<(spirv::Word, ast::Type, ast::StateSpace)> { - args.iter() - .map(|(desc, typ, space)| (desc.op, typ.clone(), *space)) - .collect::>() -} - -fn normalize_labels( - func: Vec, - id_def: &mut NumericIdResolver, -) -> Vec { - let mut labels_in_use = HashSet::new(); - for s in func.iter() { - match s { - Statement::Instruction(i) => { - if let Some(target) = i.jump_target() { - labels_in_use.insert(target); - } - } - Statement::Conditional(cond) => { - labels_in_use.insert(cond.if_true); - labels_in_use.insert(cond.if_false); - } - Statement::Call(..) - | Statement::Variable(..) - | Statement::LoadVar(..) - | Statement::StoreVar(..) - | Statement::RetValue(..) - | Statement::Conversion(..) - | Statement::Constant(..) - | Statement::Label(..) - | Statement::PtrAccess { .. } - | Statement::RepackVector(..) - | Statement::FunctionPointer(..) => {} - } - } - iter::once(Statement::Label(id_def.register_intermediate(None))) - .chain(func.into_iter().filter(|s| match s { - Statement::Label(i) => labels_in_use.contains(i), - _ => true, - })) - .collect::>() -} - -fn normalize_predicates( - func: Vec, - id_def: &mut NumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func { - match s { - Statement::Label(id) => result.push(Statement::Label(id)), - Statement::Instruction((pred, inst)) => { - if let Some(pred) = pred { - let if_true = id_def.register_intermediate(None); - let if_false = id_def.register_intermediate(None); - let folded_bra = match &inst { - ast::Instruction::Bra(_, arg) => Some(arg.src), - _ => None, - }; - let mut branch = BrachCondition { - predicate: pred.label, - if_true: folded_bra.unwrap_or(if_true), - if_false, - }; - if pred.not { - std::mem::swap(&mut branch.if_true, &mut branch.if_false); - } - result.push(Statement::Conditional(branch)); - if folded_bra.is_none() { - result.push(Statement::Label(if_true)); - result.push(Statement::Instruction(inst)); - } - result.push(Statement::Label(if_false)); - } else { - result.push(Statement::Instruction(inst)); - } - } - Statement::Variable(var) => result.push(Statement::Variable(var)), - // Blocks are flattened when resolving ids - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - -/* - How do we handle arguments: - - input .params in kernels - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - We do this for two reasons. One, common treatment for argument-declared - .param variables and .param variables inside function (we assume that - at SPIR-V level every .param is a pointer in Function storage class) - - input .params in functions - .param .b64 in_arg - get turned into this SPIR-V: - %1 = OpFunctionParameter %_ptr_Function_ulong - - input .regs - .reg .b64 in_arg - get turned into the same SPIR-V as kernel .params: - %1 = OpFunctionParameter %ulong - %2 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %1 - - output .regs - .reg .b64 out_arg - get just a variable declaration: - %2 = OpVariable %%_ptr_Function_ulong Function - - output .params don't exist, they have been moved to input positions - by an earlier pass - Distinguishing betweem kernel .params and function .params is not the - cleanest solution. Alternatively, we could "deparamize" all kernel .param - arguments by turning them into .reg arguments like this: - .param .b64 arg -> .reg ptr<.b64,.param> arg - This has the massive downside that this transformation would have to run - very early and would muddy up already difficult code. It's simpler to just - have an if here -*/ -fn insert_mem_ssa_statements<'a, 'b>( - func: Vec, - id_def: &mut NumericIdResolver, - fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for arg in fn_decl.input_arguments.iter_mut() { - insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel()); - } - for arg in fn_decl.return_arguments.iter() { - insert_mem_ssa_argument_reg_return(&mut result, arg); - } - for s in func { - match s { - Statement::Call(call) => { - insert_mem_ssa_statement_default(id_def, &mut result, call.cast())? - } - Statement::Instruction(inst) => match inst { - ast::Instruction::Ret(d) => { - // TODO: handle multiple output args - match &fn_decl.return_arguments[..] { - [return_reg] => { - let new_id = id_def.register_intermediate(Some(( - return_reg.v_type.clone(), - ast::StateSpace::Reg, - ))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: ast::Arg2 { - dst: new_id, - src: return_reg.name, - }, - // TODO: ret with stateful conversion - state_space: ast::StateSpace::Reg, - typ: return_reg.v_type.clone(), - member_index: None, - })); - result.push(Statement::RetValue(d, new_id)); - } - [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))), - _ => unimplemented!(), - } - } - inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, - }, - Statement::Conditional(bra) => { - insert_mem_ssa_statement_default(id_def, &mut result, bra)? - } - Statement::Conversion(conv) => { - insert_mem_ssa_statement_default(id_def, &mut result, conv)? - } - Statement::PtrAccess(ptr_access) => { - insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? - } - Statement::RepackVector(repack) => { - insert_mem_ssa_statement_default(id_def, &mut result, repack)? - } - Statement::FunctionPointer(func_ptr) => { - insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)? - } - s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => { - result.push(s) - } - _ => return Err(error_unreachable()), - } - } - Ok(result) -} - -fn insert_mem_ssa_argument( - id_def: &mut NumericIdResolver, - func: &mut Vec, - arg: &mut ast::Variable, - is_kernel: bool, -) { - if !is_kernel && arg.state_space == ast::StateSpace::Param { - return; - } - let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); - func.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: ast::StateSpace::Reg, - name: arg.name, - array_init: Vec::new(), - })); - func.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: arg.name, - src2: new_id, - }, - typ: arg.v_type.clone(), - member_index: None, - })); - arg.name = new_id; -} - -fn insert_mem_ssa_argument_reg_return( - func: &mut Vec, - arg: &ast::Variable, -) { - func.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: arg.v_type.clone(), - state_space: arg.state_space, - name: arg.name, - array_init: arg.array_init.clone(), - })); -} - -trait Visitable: Sized { - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError>; -} - -struct VisitArgumentDescriptor< - 'a, - Ctor: FnOnce(spirv::Word) -> Statement, U>, - U: ArgParamsEx, -> { - desc: ArgumentDescriptor, - typ: &'a ast::Type, - state_space: ast::StateSpace, - stmt_ctor: Ctor, -} - -impl< - 'a, - Ctor: FnOnce(spirv::Word) -> Statement, U>, - T: ArgParamsEx, - U: ArgParamsEx, - > Visitable for VisitArgumentDescriptor<'a, Ctor, U> -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok((self.stmt_ctor)( - visitor.id(self.desc, Some((self.typ, self.state_space)))?, - )) - } -} - -struct InsertMemSSAVisitor<'a, 'input> { - id_def: &'a mut NumericIdResolver<'input>, - func: &'a mut Vec, - post_statements: Vec, -} - -impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { - fn symbol( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - let symbol = desc.op.0; - if expected.is_none() { - return Ok(symbol); - }; - let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; - if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable { - return Ok(symbol); - }; - let member_index = match desc.op.1 { - Some(idx) => { - let vector_width = match var_type { - ast::Type::Vector(scalar_t, width) => { - var_type = ast::Type::Scalar(scalar_t); - width - } - _ => return Err(TranslateError::MismatchedType), - }; - Some(( - idx, - if self.id_def.special_registers.get(symbol).is_some() { - Some(vector_width) - } else { - None - }, - )) - } - None => None, - }; - let generated_id = self - .id_def - .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); - if !desc.is_dst { - self.func.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { - dst: generated_id, - src: symbol, - }, - state_space: ast::StateSpace::Reg, - typ: var_type, - member_index, - })); - } else { - self.post_statements - .push(Statement::StoreVar(StoreVarDetails { - arg: Arg2St { - src1: symbol, - src2: generated_id, - }, - typ: var_type, - member_index: member_index.map(|(idx, _)| idx), - })); - } - Ok(generated_id) - } -} - -impl<'a, 'input> ArgumentMapVisitor - for InsertMemSSAVisitor<'a, 'input> -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.symbol(desc.new_op((desc.op, None)), typ) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?) - } - TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset( - self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?, - offset, - ), - op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => TypedOperand::Reg( - self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?, - ), - }) - } -} - -fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( - id_def: &'a mut NumericIdResolver<'input>, - func: &'a mut Vec, - stmt: S, -) -> Result<(), TranslateError> { - let mut visitor = InsertMemSSAVisitor { - id_def, - func, - post_statements: Vec::new(), - }; - let new_stmt = stmt.visit(&mut visitor)?; - visitor.func.push(new_stmt); - visitor.func.extend(visitor.post_statements); - Ok(()) -} - -fn expand_arguments<'a, 'b>( - func: Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func { - match s { - Statement::Call(call) => { - let mut visitor = FlattenArguments::new(&mut result, id_def); - let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts); - result.push(Statement::Call(new_call)); - result.extend(post_stmts); - } - Statement::Instruction(inst) => { - let mut visitor = FlattenArguments::new(&mut result, id_def); - let (new_inst, post_stmts) = (inst.map(&mut visitor)?, visitor.post_stmts); - result.push(Statement::Instruction(new_inst)); - result.extend(post_stmts); - } - Statement::Variable(ast::Variable { - align, - v_type, - state_space, - name, - array_init, - }) => result.push(Statement::Variable(ast::Variable { - align, - v_type, - state_space, - name, - array_init, - })), - Statement::PtrAccess(ptr_access) => { - let mut visitor = FlattenArguments::new(&mut result, id_def); - let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts); - result.push(Statement::PtrAccess(new_inst)); - result.extend(post_stmts); - } - Statement::RepackVector(repack) => { - let mut visitor = FlattenArguments::new(&mut result, id_def); - let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts); - result.push(Statement::RepackVector(new_inst)); - result.extend(post_stmts); - } - Statement::Label(id) => result.push(Statement::Label(id)), - Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), - Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), - Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), - Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Constant(c) => result.push(Statement::Constant(c)), - Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), - } - } - Ok(result) -} - -struct FlattenArguments<'a, 'b> { - func: &'b mut Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, - post_stmts: Vec, -} - -impl<'a, 'b> FlattenArguments<'a, 'b> { - fn new( - func: &'b mut Vec, - id_def: &'b mut MutableNumericIdResolver<'a>, - ) -> Self { - FlattenArguments { - func, - id_def, - post_stmts: Vec::new(), - } - } - - fn reg( - &mut self, - desc: ArgumentDescriptor, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - Ok(desc.op) - } - - fn reg_offset( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, i32)>, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - let (reg, offset) = desc.op; - if !desc.is_memory_access { - let (reg_type, reg_space) = self.id_def.get_typed(reg)?; - if !reg_space.is_compatible(ast::StateSpace::Reg) { - return Err(TranslateError::MismatchedType); - } - let reg_scalar_type = match reg_type { - ast::Type::Scalar(underlying_type) => underlying_type, - _ => return Err(TranslateError::MismatchedType), - }; - let id_constant_stmt = self - .id_def - .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: reg_scalar_type, - value: ast::ImmediateValue::S64(offset as i64), - })); - let arith_details = match reg_scalar_type.kind() { - ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt { - typ: reg_scalar_type, - saturate: false, - }), - ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { - ast::ArithDetails::Unsigned(reg_scalar_type) - } - _ => return Err(error_unreachable()), - }; - let id_add_result = self.id_def.register_intermediate(reg_type, state_space); - self.func.push(Statement::Instruction(ast::Instruction::Add( - arith_details, - ast::Arg3 { - dst: id_add_result, - src1: reg, - src2: id_constant_stmt, - }, - ))); - Ok(id_add_result) - } else { - let id_constant_stmt = self.id_def.register_intermediate( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: ast::ScalarType::S64, - value: ast::ImmediateValue::S64(offset as i64), - })); - let dst = self.id_def.register_intermediate(typ.clone(), state_space); - self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: typ.clone(), - state_space: state_space, - dst, - ptr_src: reg, - offset_src: id_constant_stmt, - })); - Ok(dst) - } - } - - fn immediate( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - *scalar - } else { - todo!() - }; - let id = self - .id_def - .register_intermediate(ast::Type::Scalar(scalar_t), state_space); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value: desc.op, - })); - Ok(id) - } -} - -impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self.reg(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - match desc.op { - TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))), - TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space), - TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ, state_space) - } - TypedOperand::VecMember(..) => Err(error_unreachable()), - } - } -} - -/* - There are several kinds of implicit conversions in PTX: - * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands - * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size - - ld.param: not documented, but for instruction `ld.param. x, [y]`, - semantics are to first zext/chop/bitcast `y` as needed and then do - documented special ld/st/cvt conversion rules for destination operands - - st.param [x] y (used as function return arguments) same rule as above applies - - generic/global ld: for instruction `ld x, [y]`, y must be of type - b64/u64/s64, which is bitcast to a pointer, dereferenced and then - documented special ld/st/cvt conversion rules are applied to dst - - generic/global st: for instruction `st [x], y`, x must be of type - b64/u64/s64, which is bitcast to a pointer -*/ -fn insert_implicit_conversions( - func: Vec, - id_def: &mut MutableNumericIdResolver, -) -> Result, TranslateError> { - let mut result = Vec::with_capacity(func.len()); - for s in func.into_iter() { - match s { - Statement::Call(call) => { - insert_implicit_conversions_impl(&mut result, id_def, call)?; - } - Statement::Instruction(inst) => { - insert_implicit_conversions_impl(&mut result, id_def, inst)?; - } - Statement::PtrAccess(access) => { - insert_implicit_conversions_impl(&mut result, id_def, access)?; - } - Statement::RepackVector(repack) => { - insert_implicit_conversions_impl(&mut result, id_def, repack)?; - } - s @ Statement::Conditional(_) - | s @ Statement::Conversion(_) - | s @ Statement::Label(_) - | s @ Statement::Constant(_) - | s @ Statement::Variable(_) - | s @ Statement::LoadVar(..) - | s @ Statement::StoreVar(..) - | s @ Statement::RetValue(..) - | s @ Statement::FunctionPointer(..) => result.push(s), - } - } - Ok(result) -} - -fn insert_implicit_conversions_impl( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - stmt: impl Visitable, -) -> Result<(), TranslateError> { - let mut post_conv = Vec::new(); - let statement = - stmt.visit(&mut |desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>| { - let (instr_type, instruction_space) = match typ { - None => return Ok(desc.op), - Some(t) => t, - }; - let (operand_type, operand_space) = id_def.get_typed(desc.op)?; - let conversion_fn = desc - .non_default_implicit_conversion - .unwrap_or(default_implicit_conversion); - match conversion_fn( - (operand_space, &operand_type), - (instruction_space, instr_type), - )? { - Some(conv_kind) => { - let conv_output = if desc.is_dst { - &mut post_conv - } else { - &mut *func - }; - let mut from_type = instr_type.clone(); - let mut from_space = instruction_space; - let mut to_type = operand_type; - let mut to_space = operand_space; - let mut src = - id_def.register_intermediate(instr_type.clone(), instruction_space); - let mut dst = desc.op; - let result = Ok(src); - if !desc.is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from_type, &mut to_type); - mem::swap(&mut from_space, &mut to_space); - } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from_type, - from_space, - to_type, - to_space, - kind: conv_kind, - })); - result - } - None => Ok(desc.op), - } - })?; - func.push(statement); - func.append(&mut post_conv); - Ok(()) -} - -fn get_function_type( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - spirv_input: impl Iterator, - spirv_output: &[ast::Variable], -) -> (spirv::Word, spirv::Word) { - map.get_or_add_fn( - builder, - spirv_input, - spirv_output - .iter() - .map(|var| SpirvType::new(var.v_type.clone())), - ) -} - -fn emit_function_body_ops<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl: spirv::Word, - func: &[ExpandedStatement], -) -> Result<(), TranslateError> { - for s in func { - match s { - Statement::Label(id) => { - if builder.selected_block().is_some() { - builder.branch(*id)?; - } - builder.begin_block(Some(*id))?; - } - _ => { - if builder.selected_block().is_none() && builder.selected_function().is_some() { - builder.begin_block(None)?; - } - } - } - match s { - Statement::Label(_) => (), - Statement::Call(call) => { - let (result_type, result_id) = match &*call.return_arguments { - [(id, typ, space)] => { - if *space != ast::StateSpace::Reg { - return Err(error_unreachable()); - } - ( - map.get_or_add(builder, SpirvType::new(typ.clone())), - Some(*id), - ) - } - [] => (map.void(), None), - _ => todo!(), - }; - let arg_list = call - .input_arguments - .iter() - .map(|(id, _, _)| *id) - .collect::>(); - builder.function_call(result_type, result_id, call.name, arg_list)?; - } - Statement::Variable(var) => { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - Statement::Constant(cnst) => { - let typ_id = map.get_or_add_scalar(builder, cnst.typ); - match (cnst.typ, cnst.value) { - (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id, Some(cnst.dst), value); - } - (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64); - } - (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id, Some(cnst.dst), value as u64); - } - (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id, Some(cnst.dst), value as u64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { - builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32()); - } - (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { - builder.constant_f32(typ_id, Some(cnst.dst), value); - } - (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { - builder.constant_f64(typ_id, Some(cnst.dst), value as f64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { - builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32()); - } - (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { - builder.constant_f32(typ_id, Some(cnst.dst), value as f32); - } - (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { - builder.constant_f64(typ_id, Some(cnst.dst), value); - } - (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst)); - } else { - builder.constant_true(bool_type, Some(cnst.dst)); - } - } - (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst)); - } else { - builder.constant_true(bool_type, Some(cnst.dst)); - } - } - _ => return Err(TranslateError::MismatchedType), - } - } - Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, - Statement::Conditional(bra) => { - builder.branch_conditional( - bra.predicate, - bra.if_true, - bra.if_false, - iter::empty(), - )?; - } - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { - // TODO: implement properly - let zero = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U64), - &vec_repr(0u64), - )?; - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); - builder.copy_object(result_type, Some(*dst), zero)?; - } - Statement::Instruction(inst) => match inst { - ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?, - ast::Instruction::Call(_) => unreachable!(), - // SPIR-V does not support marking jumps as guaranteed-converged - ast::Instruction::Bra(_, arg) => { - builder.branch(arg.src)?; - } - ast::Instruction::Ld(data, arg) => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); - builder.load( - result_type, - Some(arg.dst), - arg.src, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - ast::Type::from(data.typ.clone()).size_of() as u32, - )] - .iter() - .cloned(), - )?; - } - ast::Instruction::St(data, arg) => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - builder.store( - arg.src1, - arg.src2, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - ast::Type::from(data.typ.clone()).size_of() as u32, - )] - .iter() - .cloned(), - )?; - } - // SPIR-V does not support ret as guaranteed-converged - ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(d, arg) => { - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone()))); - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } - ast::Instruction::Mul(mul, arg) => match mul { - ast::MulDetails::Signed(ref ctr) => { - emit_mul_sint(builder, map, opencl, ctr, arg)? - } - ast::MulDetails::Unsigned(ref ctr) => { - emit_mul_uint(builder, map, opencl, ctr, arg)? - } - ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?, - }, - ast::Instruction::Add(add, arg) => match add { - ast::ArithDetails::Signed(ref desc) => { - emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)? - } - ast::ArithDetails::Unsigned(ref desc) => { - emit_add_int(builder, map, (*desc).into(), false, arg)? - } - ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, - }, - ast::Instruction::Setp(setp, arg) => { - if arg.dst2.is_some() { - todo!() - } - emit_setp(builder, map, setp, arg)?; - } - ast::Instruction::Not(t, a) => { - let result_type = map.get_or_add(builder, SpirvType::from(*t)); - let result_id = Some(a.dst); - let operand = a.src; - match t { - ast::ScalarType::Pred => { - logical_not(builder, result_type, result_id, operand) - } - _ => builder.not(result_type, result_id, operand), - }?; - } - ast::Instruction::Shl(t, a) => { - let full_type = ast::Type::Scalar(*t); - let size_of = full_type.size_of(); - let result_type = map.get_or_add(builder, SpirvType::new(full_type)); - let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; - builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; - } - ast::Instruction::Shr(t, a) => { - let full_type = ast::ScalarType::from(*t); - let size_of = full_type.size_of(); - let result_type = map.get_or_add_scalar(builder, full_type); - let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?; - if t.kind() == ast::ScalarKind::Signed { - builder.shift_right_arithmetic( - result_type, - Some(a.dst), - a.src1, - offset_src, - )?; - } else { - builder.shift_right_logical( - result_type, - Some(a.dst), - a.src1, - offset_src, - )?; - } - } - ast::Instruction::Cvt(dets, arg) => { - emit_cvt(builder, map, opencl, dets, arg)?; - } - ast::Instruction::Cvta(_, arg) => { - // This would be only meaningful if const/slm/global pointers - // had a different format than generic pointers, but they don't pretty much by ptx definition - // Honestly, I have no idea why this instruction exists and is emitted by the compiler - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } - ast::Instruction::SetpBool(_, _) => todo!(), - ast::Instruction::Mad(mad, arg) => match mad { - ast::MulDetails::Signed(ref desc) => { - emit_mad_sint(builder, map, opencl, desc, arg)? - } - ast::MulDetails::Unsigned(ref desc) => { - emit_mad_uint(builder, map, opencl, desc, arg)? - } - ast::MulDetails::Float(desc) => { - emit_mad_float(builder, map, opencl, desc, arg)? - } - }, - ast::Instruction::Fma(fma, arg) => emit_fma_float(builder, map, opencl, fma, arg)?, - ast::Instruction::Or(t, a) => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::ScalarType::Pred { - builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; - } else { - builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; - } - } - ast::Instruction::Sub(d, arg) => match d { - ast::ArithDetails::Signed(desc) => { - emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?; - } - ast::ArithDetails::Unsigned(desc) => { - emit_sub_int(builder, map, (*desc).into(), false, arg)?; - } - ast::ArithDetails::Float(desc) => { - emit_sub_float(builder, map, desc, arg)?; - } - }, - ast::Instruction::Min(d, a) => { - emit_min(builder, map, opencl, d, a)?; - } - ast::Instruction::Max(d, a) => { - emit_max(builder, map, opencl, d, a)?; - } - ast::Instruction::Rcp(d, a) => { - emit_rcp(builder, map, opencl, d, a)?; - } - ast::Instruction::And(t, a) => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::ScalarType::Pred { - builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; - } else { - builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; - } - } - ast::Instruction::Selp(t, a) => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?; - } - // TODO: implement named barriers - ast::Instruction::Bar(d, _) => { - let workgroup_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(spirv::Scope::Workgroup as u32), - )?; - let barrier_semantics = match d { - ast::BarDetails::SyncAligned => map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr( - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - )?, - }; - builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?; - } - ast::Instruction::Atom(details, arg) => { - emit_atom(builder, map, details, arg)?; - } - ast::Instruction::AtomCas(details, arg) => { - let result_type = map.get_or_add_scalar(builder, details.typ.into()); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(details.scope.to_spirv() as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(details.semantics.to_spirv().bits()), - )?; - builder.atomic_compare_exchange( - result_type, - Some(arg.dst), - arg.src1, - memory_const, - semantics_const, - semantics_const, - arg.src3, - arg.src2, - )?; - } - ast::Instruction::Div(details, arg) => match details { - ast::DivDetails::Unsigned(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::DivDetails::Signed(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::DivDetails::Float(t) => { - let result_type = map.get_or_add_scalar(builder, t.typ.into()); - builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; - emit_float_div_decoration(builder, arg.dst, t.kind); - } - }, - ast::Instruction::Sqrt(details, a) => { - emit_sqrt(builder, map, opencl, details, a)?; - } - ast::Instruction::Rsqrt(details, a) => { - let result_type = map.get_or_add_scalar(builder, details.typ.into()); - builder.ext_inst( - result_type, - Some(a.dst), - opencl, - spirv::CLOp::rsqrt as spirv::Word, - [dr::Operand::IdRef(a.src)].iter().cloned(), - )?; - } - ast::Instruction::Neg(details, arg) => { - let result_type = map.get_or_add_scalar(builder, details.typ); - let negate_func = if details.typ.kind() == ast::ScalarKind::Float { - dr::Builder::f_negate - } else { - dr::Builder::s_negate - }; - negate_func(builder, result_type, Some(arg.dst), arg.src)?; - } - ast::Instruction::Sin { arg, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::sin as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - ast::Instruction::Cos { arg, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::cos as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - ast::Instruction::Lg2 { arg, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::log2 as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - ast::Instruction::Ex2 { arg, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::exp2 as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - ast::Instruction::Clz { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::clz as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - ast::Instruction::Brev { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder.bit_reverse(result_type, Some(arg.dst), arg.src)?; - } - ast::Instruction::Popc { typ, arg } => { - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder.bit_count(result_type, Some(arg.dst), arg.src)?; - } - ast::Instruction::Xor { typ, arg } => { - let builder_fn = match typ { - ast::ScalarType::Pred => emit_logical_xor_spirv, - _ => dr::Builder::bitwise_xor, - }; - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::Instruction::Bfe { .. } - | ast::Instruction::Bfi { .. } - | ast::Instruction::Activemask { .. } => { - // Should have beeen replaced with a funciton call earlier - return Err(error_unreachable()); - } - - ast::Instruction::Rem { typ, arg } => { - let builder_fn = if typ.kind() == ast::ScalarKind::Signed { - dr::Builder::s_mod - } else { - dr::Builder::u_mod - }; - let result_type = map.get_or_add_scalar(builder, (*typ).into()); - builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::Instruction::Prmt { control, arg } => { - let control = *control as u32; - let components = [ - (control >> 0) & 0b1111, - (control >> 4) & 0b1111, - (control >> 8) & 0b1111, - (control >> 12) & 0b1111, - ]; - if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo); - } - let vec4_b8_type = - map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); - let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let src1_vector = builder.bitcast(vec4_b8_type, None, arg.src1)?; - let src2_vector = builder.bitcast(vec4_b8_type, None, arg.src2)?; - let dst_vector = builder.vector_shuffle( - vec4_b8_type, - None, - src1_vector, - src2_vector, - components, - )?; - builder.bitcast(b32_type, Some(arg.dst), dst_vector)?; - } - ast::Instruction::Membar { level } => { - let (scope, semantics) = match level { - ast::MemScope::Cta => ( - spirv::Scope::Workgroup, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Gpu => ( - spirv::Scope::Device, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Sys => ( - spirv::Scope::CrossDevice, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - }; - let spirv_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope as u32), - )?; - let spirv_semantics = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics), - )?; - builder.memory_barrier(spirv_scope, spirv_semantics)?; - } - }, - Statement::LoadVar(details) => { - emit_load_var(builder, map, details)?; - } - Statement::StoreVar(details) => { - let dst_ptr = match details.member_index { - Some(index) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to( - details.typ.clone(), - spirv::StorageClass::Function, - ), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - builder.in_bounds_access_chain( - result_ptr_type, - None, - details.arg.src1, - [index_spirv].iter().copied(), - )? - } - None => details.arg.src1, - }; - builder.store(dst_ptr, details.arg.src2, None, iter::empty())?; - } - Statement::RetValue(_, id) => { - builder.ret_value(*id)?; - } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src, - }) => { - let u8_pointer = map.get_or_add( - builder, - SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), - ); - let result_type = map.get_or_add( - builder, - SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()), - ); - let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; - let temp = builder.in_bounds_ptr_access_chain( - u8_pointer, - None, - ptr_src_u8, - *offset_src, - iter::empty(), - )?; - builder.bitcast(result_type, Some(*dst), temp)?; - } - Statement::RepackVector(repack) => { - if repack.is_extract { - let scalar_type = map.get_or_add_scalar(builder, repack.typ); - for (index, dst_id) in repack.unpacked.iter().enumerate() { - builder.composite_extract( - scalar_type, - Some(*dst_id), - repack.packed, - [index as u32].iter().copied(), - )?; - } - } else { - let vector_type = map.get_or_add( - builder, - SpirvType::Vector( - SpirvScalarKey::from(repack.typ), - repack.unpacked.len() as u8, - ), - ); - let mut temp_vec = builder.undef(vector_type, None); - for (index, src_id) in repack.unpacked.iter().enumerate() { - temp_vec = builder.composite_insert( - vector_type, - None, - *src_id, - temp_vec, - [index as u32].iter().copied(), - )?; - } - builder.copy_object(vector_type, Some(repack.packed), temp_vec)?; - } - } - } - } - Ok(()) -} - -// HACK ALERT -// For some reason IGC fails linking if the value and shift size are of different type -fn insert_shift_hack( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - offset_var: spirv::Word, - size_of: usize, -) -> Result { - let result_type = match size_of { - 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), - 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), - 4 => return Ok(offset_var), - _ => return Err(error_unreachable()), - }; - Ok(builder.u_convert(result_type, None, offset_var)?) -} - -// TODO: check what kind of assembly do we emit -fn emit_logical_xor_spirv( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - op1: spirv::Word, - op2: spirv::Word, -) -> Result { - let temp_or = builder.logical_or(result_type, None, op1, op2)?; - let temp_and = builder.logical_and(result_type, None, op1, op2)?; - let temp_neg = logical_not(builder, result_type, None, temp_and)?; - builder.logical_and(result_type, result_id, temp_or, temp_neg) -} - -fn emit_sqrt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - details: &ast::SqrtDetails, - a: &ast::Arg2, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add_scalar(builder, details.typ.into()); - let (ocl_op, rounding) = match details.kind { - ast::SqrtKind::Approx => (spirv::CLOp::sqrt, None), - ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)), - }; - builder.ext_inst( - result_type, - Some(a.dst), - opencl, - ocl_op as spirv::Word, - [dr::Operand::IdRef(a.src)].iter().cloned(), - )?; - emit_rounding_decoration(builder, a.dst, rounding); - Ok(()) -} - -fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) { - match kind { - ast::DivFloatKind::Approx => { - builder.decorate( - dst, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - } - ast::DivFloatKind::Rounding(rnd) => { - emit_rounding_decoration(builder, dst, Some(rnd)); - } - ast::DivFloatKind::Full => {} - } -} - -fn emit_atom( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &ast::AtomDetails, - arg: &ast::Arg3, -) -> Result<(), TranslateError> { - let (spirv_op, typ) = match details.inner { - ast::AtomInnerDetails::Bit { op, typ } => { - let spirv_op = match op { - ast::AtomBitOp::And => dr::Builder::atomic_and, - ast::AtomBitOp::Or => dr::Builder::atomic_or, - ast::AtomBitOp::Xor => dr::Builder::atomic_xor, - ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange, - }; - (spirv_op, ast::ScalarType::from(typ)) - } - ast::AtomInnerDetails::Unsigned { op, typ } => { - let spirv_op = match op { - ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, - ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { - return Err(error_unreachable()); - } - ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, - ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, - }; - (spirv_op, typ.into()) - } - ast::AtomInnerDetails::Signed { op, typ } => { - let spirv_op = match op { - ast::AtomSIntOp::Add => dr::Builder::atomic_i_add, - ast::AtomSIntOp::Min => dr::Builder::atomic_s_min, - ast::AtomSIntOp::Max => dr::Builder::atomic_s_max, - }; - (spirv_op, typ.into()) - } - ast::AtomInnerDetails::Float { op, typ } => { - let spirv_op: fn(&mut dr::Builder, _, _, _, _, _, _) -> _ = match op { - ast::AtomFloatOp::Add => dr::Builder::atomic_f_add_ext, - }; - (spirv_op, typ.into()) - } - }; - let result_type = map.get_or_add_scalar(builder, typ); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(details.scope.to_spirv() as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(details.semantics.to_spirv().bits()), - )?; - spirv_op( - builder, - result_type, - Some(arg.dst), - arg.src1, - memory_const, - semantics_const, - arg.src2, - )?; - Ok(()) -} - -#[derive(Clone)] -struct PtxImplImport { - out_arg: ast::Type, - fn_id: u32, - in_args: Vec, -} - -fn emit_mul_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - ctr: &ast::ArithFloat, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - if ctr.saturate { - todo!() - } - let result_type = map.get_or_add_scalar(builder, ctr.typ.into()); - builder.f_mul(result_type, Some(arg.dst), arg.src1, arg.src2)?; - emit_rounding_decoration(builder, arg.dst, ctr.rounding); - Ok(()) -} - -fn emit_rcp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::RcpDetails, - arg: &ast::Arg2, -) -> Result<(), TranslateError> { - let (instr_type, constant) = if desc.is_f64 { - (ast::ScalarType::F64, vec_repr(1.0f64)) - } else { - (ast::ScalarType::F32, vec_repr(1.0f32)) - }; - let result_type = map.get_or_add_scalar(builder, instr_type); - if !desc.is_f64 && desc.rounding.is_none() { - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::native_recip as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - return Ok(()); - } - let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; - builder.f_div(result_type, Some(arg.dst), one, arg.src)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - builder.decorate( - arg.dst, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - Ok(()) -} - -fn vec_repr(t: T) -> Vec { - let mut result = vec![0; mem::size_of::()]; - unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; - result -} - -fn emit_variable<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - linking: ast::LinkingDirective, - var: &ast::Variable, -) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.state_space { - ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { - (false, spirv::StorageClass::Function) - } - ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), - ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), - ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), - ast::StateSpace::Generic => todo!(), - ast::StateSpace::Sreg => todo!(), - }; - let initalizer = if var.array_init.len() > 0 { - Some(map.get_or_add_constant( - builder, - &ast::Type::from(var.v_type.clone()), - &*var.array_init, - )?) - } else if must_init { - let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); - Some(builder.constant_null(type_id, None)) - } else { - None - }; - let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); - builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); - if let Some(align) = var.align { - builder.decorate( - var.name, - spirv::Decoration::Alignment, - [dr::Operand::LiteralInt32(align)].iter().cloned(), - ); - } - if var.state_space != ast::StateSpace::Shared - || !linking.contains(ast::LinkingDirective::EXTERN) - { - emit_linking_decoration(builder, id_defs, None, var.name, linking); - } - Ok(()) -} - -fn emit_linking_decoration<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - name_override: Option<&str>, - name: spirv::Word, - linking: ast::LinkingDirective, -) { - if linking == ast::LinkingDirective::NONE { - return; - } - if linking.contains(ast::LinkingDirective::VISIBLE) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Export), - ] - .iter() - .cloned(), - ); - } else if linking.contains(ast::LinkingDirective::EXTERN) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Import), - ] - .iter() - .cloned(), - ); - } - // TODO: handle LinkingDirective::WEAK -} - -fn emit_mad_uint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulUInt, - arg: &ast::Arg4, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - match desc.control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?; - builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::u_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1), - dr::Operand::IdRef(arg.src2), - dr::Operand::IdRef(arg.src3), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_mad_sint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulSInt, - arg: &ast::Arg4, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - match desc.control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?; - builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::s_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1), - dr::Operand::IdRef(arg.src2), - dr::Operand::IdRef(arg.src3), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_fma_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::Arg4, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::fma as spirv::Word, - [ - dr::Operand::IdRef(arg.src1), - dr::Operand::IdRef(arg.src2), - dr::Operand::IdRef(arg.src3), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_mad_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::Arg4, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::mad as spirv::Word, - [ - dr::Operand::IdRef(arg.src1), - dr::Operand::IdRef(arg.src2), - dr::Operand::IdRef(arg.src3), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_add_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_sub_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_min( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, - }; - let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - cl_op as spirv::Word, - [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_max( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, - }; - let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type())); - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - cl_op as spirv::Word, - [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_cvt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - dets: &ast::CvtDetails, - arg: &ast::Arg2, -) -> Result<(), TranslateError> { - match dets { - ast::CvtDetails::FloatFromFloat(desc) => { - if desc.saturate { - todo!() - } - let dest_t: ast::ScalarType = desc.dst.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.dst == desc.src { - match desc.rounding { - Some(ast::RoundingMode::NearestEven) => { - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::rint as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::Zero) => { - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::trunc as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::NegativeInf) => { - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::floor as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::PositiveInf) => { - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - spirv::CLOp::ceil as u32, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - } - None => { - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } - } - } else { - builder.f_convert(result_type, Some(arg.dst), arg.src)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - } - } - ast::CvtDetails::FloatFromInt(desc) => { - if desc.saturate { - todo!() - } - let dest_t: ast::ScalarType = desc.dst.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.src.kind() == ast::ScalarKind::Signed { - builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; - } else { - builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; - } - emit_rounding_decoration(builder, arg.dst, desc.rounding); - } - ast::CvtDetails::IntFromFloat(desc) => { - let dest_t: ast::ScalarType = desc.dst.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.dst.kind() == ast::ScalarKind::Signed { - builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; - } else { - builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; - } - emit_rounding_decoration(builder, arg.dst, desc.rounding); - emit_saturating_decoration(builder, arg.dst, desc.saturate); - } - ast::CvtDetails::IntFromInt(desc) => { - let dest_t: ast::ScalarType = desc.dst.into(); - let src_t: ast::ScalarType = desc.src.into(); - // first do shortening/widening - let src = if desc.dst.size_of() != desc.src.size_of() { - let new_dst = if dest_t.kind() == src_t.kind() { - arg.dst - } else { - builder.id() - }; - let cv = ImplicitConversion { - src: arg.src, - dst: new_dst, - from_type: ast::Type::Scalar(src_t), - from_space: ast::StateSpace::Reg, - to_type: ast::Type::Scalar(ast::ScalarType::from_parts( - dest_t.size_of(), - src_t.kind(), - )), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::Default, - }; - emit_implicit_conversion(builder, map, &cv)?; - new_dst - } else { - arg.src - }; - if dest_t.kind() == src_t.kind() { - return Ok(()); - } - // now do actual conversion - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - if desc.saturate { - if desc.dst.kind() == ast::ScalarKind::Signed { - builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; - } else { - builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; - } - } else { - builder.bitcast(result_type, Some(arg.dst), src)?; - } - } - } - Ok(()) -} - -fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) { - if saturate { - builder.decorate(dst, spirv::Decoration::SaturatedConversion, iter::empty()); - } -} - -fn emit_rounding_decoration( - builder: &mut dr::Builder, - dst: spirv::Word, - rounding: Option, -) { - if let Some(rounding) = rounding { - builder.decorate( - dst, - spirv::Decoration::FPRoundingMode, - [rounding.to_spirv()].iter().cloned(), - ); - } -} - -impl ast::RoundingMode { - fn to_spirv(self) -> rspirv::dr::Operand { - let mode = match self { - ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, - ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, - ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, - ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, - }; - rspirv::dr::Operand::FPRoundingMode(mode) - } -} - -fn emit_setp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - setp: &ast::SetpData, - arg: &ast::Arg4Setp, -) -> Result<(), dr::Error> { - let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)); - let result_id = Some(arg.dst1); - let operand_1 = arg.src1; - let operand_2 = arg.src2; - match (setp.cmp_op, setp.typ.kind()) { - (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed) - | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => { - builder.i_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => { - builder.f_ord_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed) - | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => { - builder.i_not_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => { - builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => { - builder.u_less_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => { - builder.s_less_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => { - builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => { - builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => { - builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => { - builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => { - builder.u_greater_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => { - builder.s_greater_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => { - builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned) - | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => { - builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => { - builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => { - builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanEq, _) => { - builder.f_unord_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanNotEq, _) => { - builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanLess, _) => { - builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanLessOrEq, _) => { - builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanGreater, _) => { - builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::NanGreaterOrEq, _) => { - builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - (ast::SetpCompareOp::IsAnyNan, _) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - builder.logical_or(result_type, result_id, temp1, temp2) - } - (ast::SetpCompareOp::IsNotNan, _) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; - logical_not(builder, result_type, result_id, any_nan) - } - _ => todo!(), - }?; - Ok(()) -} - -// HACK ALERT -// Temporary workaround until IGC gets its shit together -// Currently IGC carries two copies of SPIRV-LLVM translator -// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. -// Obviously, old and buggy one is used for compiling L0 SPIRV -// https://github.com/intel/intel-graphics-compiler/issues/148 -fn logical_not( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - operand: spirv::Word, -) -> Result { - let const_true = builder.constant_true(result_type, None); - let const_false = builder.constant_false(result_type, None); - builder.select(result_type, result_id, operand, const_false, const_true) -} - -fn emit_mul_sint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulSInt, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let instruction_type = desc.typ; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.typ)); - match desc.control { - ast::MulIntControl::Low => { - builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::s_mul_hi as spirv::Word, - [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => { - let instr_width = instruction_type.size_of(); - let instr_kind = instruction_type.kind(); - let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); - let dst_type_id = map.get_or_add_scalar(builder, dst_type); - let src1 = builder.s_convert(dst_type_id, None, arg.src1)?; - let src2 = builder.s_convert(dst_type_id, None, arg.src2)?; - builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?; - builder.decorate(arg.dst, spirv::Decoration::NoSignedWrap, iter::empty()); - } - } - Ok(()) -} - -fn emit_mul_uint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MulUInt, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - let instruction_type = ast::ScalarType::from(desc.typ); - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); - match desc.control { - ast::MulIntControl::Low => { - builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst), - opencl, - spirv::CLOp::u_mul_hi as spirv::Word, - [dr::Operand::IdRef(arg.src1), dr::Operand::IdRef(arg.src2)] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => { - let instr_width = instruction_type.size_of(); - let instr_kind = instruction_type.kind(); - let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); - let dst_type_id = map.get_or_add_scalar(builder, dst_type); - let src1 = builder.u_convert(dst_type_id, None, arg.src1)?; - let src2 = builder.u_convert(dst_type_id, None, arg.src2)?; - builder.i_mul(dst_type_id, Some(arg.dst), src1, src2)?; - builder.decorate(arg.dst, spirv::Decoration::NoUnsignedWrap, iter::empty()); - } - } - Ok(()) -} - -// Surprisingly, structs can't be bitcast, so we route everything through a vector -fn struct2_bitcast_to_wide( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - base_type_key: SpirvScalarKey, - instruction_type: spirv::Word, - dst: spirv::Word, - dst_type_id: spirv::Word, - src: spirv::Word, -) -> Result<(), dr::Error> { - let low_bits = builder.composite_extract(instruction_type, None, src, [0].iter().copied())?; - let high_bits = builder.composite_extract(instruction_type, None, src, [1].iter().copied())?; - let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2)); - let vector = - builder.composite_construct(vector_type, None, [low_bits, high_bits].iter().copied())?; - builder.bitcast(dst_type_id, Some(dst), vector)?; - Ok(()) -} - -fn emit_abs( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - d: &ast::AbsDetails, - arg: &ast::Arg2, -) -> Result<(), dr::Error> { - let scalar_t = ast::ScalarType::from(d.typ); - let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { - spirv::CLOp::s_abs - } else { - spirv::CLOp::fabs - }; - builder.ext_inst( - result_type, - Some(arg.dst), - opencl, - cl_abs as spirv::Word, - [dr::Operand::IdRef(arg.src)].iter().cloned(), - )?; - Ok(()) -} - -fn emit_add_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); - builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - Ok(()) -} - -fn emit_sub_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::Arg3, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); - builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; - Ok(()) -} - -fn emit_implicit_conversion( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - cv: &ImplicitConversion, -) -> Result<(), TranslateError> { - let from_parts = cv.from_type.to_parts(); - let to_parts = cv.to_type.to_parts(); - match (from_parts.kind, to_parts.kind, &cv.kind) { - (_, _, &ConversionKind::BitToPtr) => { - let dst_type = map.get_or_add( - builder, - SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()), - ); - builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { - if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - if from_parts.scalar_kind != ast::ScalarKind::Float - && to_parts.scalar_kind != ast::ScalarKind::Float - { - // It is noop, but another instruction expects result of this conversion - builder.copy_object(dst_type, Some(cv.dst), cv.src)?; - } else { - builder.bitcast(dst_type, Some(cv.dst), cv.src)?; - } - } else { - // This block is safe because it's illegal to implictly convert between floating point values - let same_width_bit_type = map.get_or_add( - builder, - SpirvType::new(ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - })), - ); - let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; - let wide_bit_type = ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..to_parts - }); - let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); - if to_parts.scalar_kind == ast::ScalarKind::Unsigned - || to_parts.scalar_kind == ast::ScalarKind::Bit - { - builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; - } else { - let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed - && to_parts.scalar_kind == ast::ScalarKind::Signed - { - dr::Builder::s_convert - } else { - dr::Builder::u_convert - }; - let wide_bit_value = - conversion_fn(builder, wide_bit_type_spirv, None, same_width_bit_value)?; - emit_implicit_conversion( - builder, - map, - &ImplicitConversion { - src: wide_bit_value, - dst: cv.dst, - from_type: wide_bit_type, - from_space: cv.from_space, - to_type: cv.to_type.clone(), - to_space: cv.to_space, - kind: ConversionKind::Default, - }, - )?; - } - } - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.s_convert(result_type, Some(cv.dst), cv.src)?; - } - (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.bitcast(into_type, Some(cv.dst), cv.src)?; - } - (_, _, &ConversionKind::PtrToPtr) => { - let result_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - cv.to_space.to_spirv(), - ), - ); - if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - cv.from_space.to_spirv(), - ), - ); - builder.bitcast(temp_type, None, cv.src)? - } else { - cv.src - }; - builder.ptr_cast_to_generic(result_type, Some(cv.dst), src)?; - } else if cv.from_space == ast::StateSpace::Generic - && cv.to_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - cv.from_space.to_spirv(), - ), - ); - builder.bitcast(temp_type, None, cv.src)? - } else { - cv.src - }; - builder.generic_cast_to_ptr(result_type, Some(cv.dst), src)?; - } else { - builder.bitcast(result_type, Some(cv.dst), cv.src)?; - } - } - (_, _, &ConversionKind::AddressOf) => { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; - } - (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?; - } - (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?; - } - _ => unreachable!(), - } - Ok(()) -} - -fn emit_load_var( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &LoadVarDetails, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); - match details.member_index { - Some((index, Some(width))) => { - let vector_type = match details.typ { - ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), - _ => return Err(TranslateError::MismatchedType), - }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); - let vector_temp = builder.load( - vector_type_spirv, - None, - details.arg.src, - None, - iter::empty(), - )?; - builder.composite_extract( - result_type, - Some(details.arg.dst), - vector_temp, - [index as u32].iter().copied(), - )?; - } - Some((index, None)) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - let src = builder.in_bounds_access_chain( - result_ptr_type, - None, - details.arg.src, - [index_spirv].iter().copied(), - )?; - builder.load(result_type, Some(details.arg.dst), src, None, iter::empty())?; - } - None => { - builder.load( - result_type, - Some(details.arg.dst), - details.arg.src, - None, - iter::empty(), - )?; - } - }; - Ok(()) -} - -fn normalize_identifiers<'input, 'b>( - id_defs: &mut FnStringIdResolver<'input, 'b>, - fn_defs: &GlobalFnDeclResolver<'input, 'b>, - func: Vec>>, -) -> Result, TranslateError> { - for s in func.iter() { - match s { - ast::Statement::Label(id) => { - id_defs.add_def(*id, None, false); - } - _ => (), - } - } - let mut result = Vec::new(); - for s in func { - expand_map_variables(id_defs, fn_defs, &mut result, s)?; - } - Ok(result) -} - -fn expand_map_variables<'a, 'b>( - id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver<'a, 'b>, - result: &mut Vec, - s: ast::Statement>, -) -> Result<(), TranslateError> { - match s { - ast::Statement::Block(block) => { - id_defs.start_block(); - for s in block { - expand_map_variables(id_defs, fn_defs, result, s)?; - } - id_defs.end_block(); - } - ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), - ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( - p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))) - .transpose()?, - i.map_variable(&mut |id| id_defs.get_id(id))?, - ))), - ast::Statement::Variable(var) => { - let var_type = var.var.v_type.clone(); - match var.count { - Some(count) => { - for new_id in - id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) - { - result.push(Statement::Variable(ast::Variable { - align: var.var.align, - v_type: var.var.v_type.clone(), - state_space: var.var.state_space, - name: new_id, - array_init: var.var.array_init.clone(), - })) - } - } - None => { - let new_id = - id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); - result.push(Statement::Variable(ast::Variable { - align: var.var.align, - v_type: var.var.v_type.clone(), - state_space: var.var.state_space, - name: new_id, - array_init: var.var.array_init, - })); - } - } - } - }; - Ok(()) -} - -/* - Our goal here is to transform - .visible .entry foobar(.param .u64 input) { - .reg .b64 in_addr; - .reg .b64 in_addr2; - ld.param.u64 in_addr, [input]; - cvta.to.global.u64 in_addr2, in_addr; - } - into: - .visible .entry foobar(.param .u8 input[]) { - .reg .u8 in_addr[]; - .reg .u8 in_addr2[]; - ld.param.u8[] in_addr, [input]; - mov.u8[] in_addr2, in_addr; - } - or: - .visible .entry foobar(.reg .u8 input[]) { - .reg .u8 in_addr[]; - .reg .u8 in_addr2[]; - mov.u8[] in_addr, input; - mov.u8[] in_addr2, in_addr; - } - or: - .visible .entry foobar(.param ptr input) { - .reg ptr in_addr; - .reg ptr in_addr2; - ld.param.ptr in_addr, [input]; - mov.ptr in_addr2, in_addr; - } -*/ -// TODO: detect more patterns (mov, call via reg, call via param) -// TODO: don't convert to ptr if the register is not ultimately used for ld/st -// TODO: once insert_mem_ssa_statements is moved to later, move this pass after -// argument expansion -// TODO: propagate out of calls and into calls -fn convert_to_stateful_memory_access<'a, 'input>( - func_args: Rc>>, - func_body: Vec, - id_defs: &mut NumericIdResolver<'a>, -) -> Result< - ( - Rc>>, - Vec, - ), - TranslateError, -> { - let mut method_decl = func_args.borrow_mut(); - if !method_decl.name.is_kernel() { - drop(method_decl); - return Ok((func_args, func_body)); - } - if Rc::strong_count(&func_args) != 1 { - return Err(error_unreachable()); - } - let func_args_64bit = (*method_decl) - .input_arguments - .iter() - .filter_map(|arg| match arg.v_type { - ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), - _ => None, - }) - .collect::>(); - let mut stateful_markers = Vec::new(); - let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); - for statement in func_body.iter() { - match statement { - Statement::Instruction(ast::Instruction::Cvta( - ast::CvtaDetails { - to: ast::StateSpace::Global, - size: ast::CvtaSize::U64, - from: ast::StateSpace::Generic, - }, - arg, - )) => { - if let (TypedOperand::Reg(dst), Some(src)) = - (arg.dst, arg.src.upcast().underlying_register()) - { - if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { - stateful_markers.push((dst, *src)); - } - } - } - Statement::Instruction(ast::Instruction::Ld( - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::U64), - .. - }, - arg, - )) - | Statement::Instruction(ast::Instruction::Ld( - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::S64), - .. - }, - arg, - )) - | Statement::Instruction(ast::Instruction::Ld( - ast::LdDetails { - state_space: ast::StateSpace::Param, - typ: ast::Type::Scalar(ast::ScalarType::B64), - .. - }, - arg, - )) => { - if let (TypedOperand::Reg(dst), Some(src)) = - (&arg.dst, arg.src.upcast().underlying_register()) - { - if func_args_64bit.contains(src) { - multi_hash_map_append(&mut stateful_init_reg, *dst, *src); - } - } - } - _ => {} - } - } - if stateful_markers.len() == 0 { - drop(method_decl); - return Ok((func_args, func_body)); - } - let mut func_args_ptr = HashSet::new(); - let mut regs_ptr_current = HashSet::new(); - for (dst, src) in stateful_markers { - if let Some(func_args) = stateful_init_reg.get(&src) { - for a in func_args { - func_args_ptr.insert(*a); - regs_ptr_current.insert(src); - regs_ptr_current.insert(dst); - } - } - } - // BTreeSet here to have a stable order of iteration, - // unfortunately our tests rely on it - let mut regs_ptr_seen = BTreeSet::new(); - while regs_ptr_current.len() > 0 { - let mut regs_ptr_new = HashSet::new(); - for statement in func_body.iter() { - match statement { - Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::ScalarType::U64), - arg, - )) - | Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, - saturate: false, - }), - arg, - )) - | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::ScalarType::U64), - arg, - )) - | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, - saturate: false, - }), - arg, - )) => { - // TODO: don't mark result of double pointer sub or double - // pointer add as ptr result - if let (TypedOperand::Reg(dst), Some(src1)) = - (arg.dst, arg.src1.upcast().underlying_register()) - { - if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { - regs_ptr_new.insert(dst); - } - } else if let (TypedOperand::Reg(dst), Some(src2)) = - (arg.dst, arg.src2.upcast().underlying_register()) - { - if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { - regs_ptr_new.insert(dst); - } - } - } - _ => {} - } - } - for id in regs_ptr_current { - regs_ptr_seen.insert(id); - } - regs_ptr_current = regs_ptr_new; - } - drop(regs_ptr_current); - let mut remapped_ids = HashMap::new(); - let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); - for reg in regs_ptr_seen { - let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Reg, - ); - result.push(Statement::Variable(ast::Variable { - align: None, - name: new_id, - array_init: Vec::new(), - v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - state_space: ast::StateSpace::Reg, - })); - remapped_ids.insert(reg, new_id); - } - for arg in (*method_decl).input_arguments.iter_mut() { - if !func_args_ptr.contains(&arg.name) { - continue; - } - let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Param, - ); - let old_name = arg.name; - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); - arg.name = new_id; - remapped_ids.insert(old_name, new_id); - } - for statement in func_body { - match statement { - l @ Statement::Label(_) => result.push(l), - c @ Statement::Conditional(_) => result.push(c), - c @ Statement::Constant(..) => result.push(c), - Statement::Variable(var) => { - if !remapped_ids.contains_key(&var.name) { - result.push(Statement::Variable(var)); - } - } - Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Unsigned(ast::ScalarType::U64), - arg, - )) - | Statement::Instruction(ast::Instruction::Add( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, - saturate: false, - }), - arg, - )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.upcast().underlying_register() { - Some(src1) if remapped_ids.contains_key(src1) => { - (remapped_ids.get(src1).unwrap(), arg.src2) - } - Some(src2) if remapped_ids.contains_key(src2) => { - (remapped_ids.get(src2).unwrap(), arg.src1) - } - _ => return Err(error_unreachable()), - }; - let dst = arg.dst.upcast().unwrap_reg()?; - result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, - dst: *remapped_ids.get(&dst).unwrap(), - ptr_src: *ptr, - offset_src: offset, - })) - } - Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Unsigned(ast::ScalarType::U64), - arg, - )) - | Statement::Instruction(ast::Instruction::Sub( - ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::ScalarType::S64, - saturate: false, - }), - arg, - )) if is_sub_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.upcast().underlying_register() { - Some(src1) => (remapped_ids.get(src1).unwrap(), arg.src2), - _ => return Err(error_unreachable()), - }; - let offset_neg = id_defs.register_intermediate(Some(( - ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - ))); - result.push(Statement::Instruction(ast::Instruction::Neg( - ast::NegDetails { - typ: ast::ScalarType::S64, - flush_to_zero: None, - }, - ast::Arg2 { - src: offset, - dst: TypedOperand::Reg(offset_neg), - }, - ))); - let dst = arg.dst.upcast().unwrap_reg()?; - result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::Type::Scalar(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, - dst: *remapped_ids.get(&dst).unwrap(), - ptr_src: *ptr, - offset_src: TypedOperand::Reg(offset_neg), - })) - } - Statement::Instruction(inst) => { - let mut post_statements = Vec::new(); - let new_statement = - inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { - convert_to_stateful_memory_access_postprocess( - id_defs, - &remapped_ids, - &mut result, - &mut post_statements, - arg_desc, - expected_type, - ) - })?; - result.push(new_statement); - result.extend(post_statements); - } - Statement::Call(call) => { - let mut post_statements = Vec::new(); - let new_statement = - call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { - convert_to_stateful_memory_access_postprocess( - id_defs, - &remapped_ids, - &mut result, - &mut post_statements, - arg_desc, - expected_type, - ) - })?; - result.push(new_statement); - result.extend(post_statements); - } - Statement::RepackVector(pack) => { - let mut post_statements = Vec::new(); - let new_statement = - pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { - convert_to_stateful_memory_access_postprocess( - id_defs, - &remapped_ids, - &mut result, - &mut post_statements, - arg_desc, - expected_type, - ) - })?; - result.push(new_statement); - result.extend(post_statements); - } - _ => return Err(error_unreachable()), - } - } - drop(method_decl); - Ok((func_args, result)) -} - -fn convert_to_stateful_memory_access_postprocess( - id_defs: &mut NumericIdResolver, - remapped_ids: &HashMap, - result: &mut Vec, - post_statements: &mut Vec, - arg_desc: ArgumentDescriptor, - expected_type: Option<(&ast::Type, ast::StateSpace)>, -) -> Result { - Ok(match remapped_ids.get(&arg_desc.op) { - Some(new_id) => { - let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; - if let Some((expected_type, expected_space)) = expected_type { - let implicit_conversion = arg_desc - .non_default_implicit_conversion - .unwrap_or(default_implicit_conversion); - if implicit_conversion( - (new_operand_space, &new_operand_type), - (expected_space, expected_type), - ) - .is_ok() - { - return Ok(*new_id); - } - } - let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; - let converting_id = - id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { - ConversionKind::Default - } else { - ConversionKind::PtrToPtr - }; - if arg_desc.is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_operand_type, - from_space: old_operand_space, - to_type: new_operand_type, - to_space: new_operand_space, - kind, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: new_operand_type, - from_space: new_operand_space, - to_type: old_operand_type, - to_space: old_operand_space, - kind, - })); - converting_id - } - } - None => arg_desc.op, - }) -} - -fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { - match arg.dst { - TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { - return false - } - TypedOperand::Reg(dst) => { - if !remapped_ids.contains_key(&dst) { - return false; - } - if let Some(src1_reg) = arg.src1.upcast().underlying_register() { - if remapped_ids.contains_key(src1_reg) { - // don't trigger optimization when adding two pointers - if let Some(src2_reg) = arg.src2.upcast().underlying_register() { - return !remapped_ids.contains_key(src2_reg); - } - } - } - if let Some(src2_reg) = arg.src2.upcast().underlying_register() { - remapped_ids.contains_key(src2_reg) - } else { - false - } - } - } -} - -fn is_sub_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { - match arg.dst { - TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { - return false - } - TypedOperand::Reg(dst) => { - if !remapped_ids.contains_key(&dst) { - return false; - } - match arg.src1.upcast().underlying_register() { - Some(src1_reg) => { - if remapped_ids.contains_key(src1_reg) { - // don't trigger optimization when subtracting two pointers - arg.src2 - .upcast() - .underlying_register() - .map_or(true, |src2_reg| !remapped_ids.contains_key(src2_reg)) - } else { - false - } - } - None => false, - } - } - } -} - -fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { - match id_defs.get_typed(id) { - Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) - | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) - | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, - _ => false, - } -} - -#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] -enum PtxSpecialRegister { - Tid, - Ntid, - Ctaid, - Nctaid, - Clock, - LanemaskLt, -} - -impl PtxSpecialRegister { - fn try_parse(s: &str) -> Option { - match s { - "%tid" => Some(Self::Tid), - "%ntid" => Some(Self::Ntid), - "%ctaid" => Some(Self::Ctaid), - "%nctaid" => Some(Self::Nctaid), - "%clock" => Some(Self::Clock), - "%lanemask_lt" => Some(Self::LanemaskLt), - _ => None, - } - } - - fn get_type(self) -> ast::Type { - match self { - PtxSpecialRegister::Tid - | PtxSpecialRegister::Ntid - | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4), - _ => ast::Type::Scalar(self.get_function_return_type()), - } - } - - fn get_function_return_type(self) -> ast::ScalarType { - match self { - PtxSpecialRegister::Tid => ast::ScalarType::U32, - PtxSpecialRegister::Ntid => ast::ScalarType::U32, - PtxSpecialRegister::Ctaid => ast::ScalarType::U32, - PtxSpecialRegister::Nctaid => ast::ScalarType::U32, - PtxSpecialRegister::Clock => ast::ScalarType::U32, - PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, - } - } - - fn get_function_input_type(self) -> Option { - match self { - PtxSpecialRegister::Tid - | PtxSpecialRegister::Ntid - | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), - PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, - } - } - - fn get_unprefixed_function_name(self) -> &'static str { - match self { - PtxSpecialRegister::Tid => "sreg_tid", - PtxSpecialRegister::Ntid => "sreg_ntid", - PtxSpecialRegister::Ctaid => "sreg_ctaid", - PtxSpecialRegister::Nctaid => "sreg_nctaid", - PtxSpecialRegister::Clock => "sreg_clock", - PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", - } - } -} - -struct SpecialRegistersMap { - reg_to_id: HashMap, - id_to_reg: HashMap, -} - -impl SpecialRegistersMap { - fn new() -> Self { - SpecialRegistersMap { - reg_to_id: HashMap::new(), - id_to_reg: HashMap::new(), - } - } - - fn get(&self, id: spirv::Word) -> Option { - self.id_to_reg.get(&id).copied() - } - - fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word { - match self.reg_to_id.entry(reg) { - hash_map::Entry::Occupied(e) => *e.get(), - hash_map::Entry::Vacant(e) => { - let numeric_id = *current_id; - *current_id += 1; - e.insert(numeric_id); - self.id_to_reg.insert(numeric_id, reg); - numeric_id - } - } - } -} - -struct FnSigMapper<'input> { - // true - stays as return argument - // false - is moved to input argument - return_param_args: Vec, - func_decl: Rc>>, -} - -impl<'input> FnSigMapper<'input> { - fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self { - let return_param_args = method - .return_arguments - .iter() - .map(|a| a.state_space != ast::StateSpace::Param) - .collect::>(); - let mut new_return_arguments = Vec::new(); - for arg in method.return_arguments.into_iter() { - if arg.state_space == ast::StateSpace::Param { - method.input_arguments.push(arg); - } else { - new_return_arguments.push(arg); - } - } - method.return_arguments = new_return_arguments; - FnSigMapper { - return_param_args, - func_decl: Rc::new(RefCell::new(method)), - } - } - - fn resolve_in_spirv_repr( - &self, - call_inst: ast::CallInst, - ) -> Result, TranslateError> { - let func_decl = (*self.func_decl).borrow(); - let mut return_arguments = Vec::new(); - let mut input_arguments = call_inst - .param_list - .into_iter() - .zip(func_decl.input_arguments.iter()) - .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) - .collect::>(); - let mut func_decl_return_iter = func_decl.return_arguments.iter(); - let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); - for (idx, id) in call_inst.ret_params.iter().enumerate() { - let stays_as_return = match self.return_param_args.get(idx) { - Some(x) => *x, - None => return Err(TranslateError::MismatchedType), - }; - if stays_as_return { - if let Some(var) = func_decl_return_iter.next() { - return_arguments.push((*id, var.v_type.clone(), var.state_space)); - } else { - return Err(TranslateError::MismatchedType); - } - } else { - if let Some(var) = func_decl_input_iter.next() { - input_arguments.push(( - ast::Operand::Reg(*id), - var.v_type.clone(), - var.state_space, - )); - } else { - return Err(TranslateError::MismatchedType); - } - } - } - if return_arguments.len() != func_decl.return_arguments.len() - || input_arguments.len() != func_decl.input_arguments.len() - { - return Err(TranslateError::MismatchedType); - } - Ok(ResolvedCall { - return_arguments, - input_arguments, - uniform: call_inst.uniform, - name: call_inst.func, - }) - } -} - -struct GlobalStringIdResolver<'input> { - current_id: spirv::Word, - variables: HashMap, spirv::Word>, - reverse_variables: HashMap, - variables_type_check: HashMap>, - special_registers: SpecialRegistersMap, - fns: HashMap>, -} - -impl<'input> GlobalStringIdResolver<'input> { - fn new(start_id: spirv::Word) -> Self { - Self { - current_id: start_id, - variables: HashMap::new(), - reverse_variables: HashMap::new(), - variables_type_check: HashMap::new(), - special_registers: SpecialRegistersMap::new(), - fns: HashMap::new(), - } - } - - fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word { - self.get_or_add_impl(id, None) - } - - fn get_or_add_def_typed( - &mut self, - id: &'input str, - typ: ast::Type, - state_space: ast::StateSpace, - is_variable: bool, - ) -> spirv::Word { - self.get_or_add_impl(id, Some((typ, state_space, is_variable))) - } - - fn get_or_add_impl( - &mut self, - id: &'input str, - typ: Option<(ast::Type, ast::StateSpace, bool)>, - ) -> spirv::Word { - let id = match self.variables.entry(Cow::Borrowed(id)) { - hash_map::Entry::Occupied(e) => *(e.get()), - hash_map::Entry::Vacant(e) => { - let numeric_id = self.current_id; - e.insert(numeric_id); - self.reverse_variables.insert(numeric_id, id); - self.current_id += 1; - numeric_id - } - }; - self.variables_type_check.insert(id, typ); - id - } - - fn get_id(&self, id: &str) -> Result { - self.variables - .get(id) - .copied() - .ok_or_else(error_unknown_symbol) - } - - fn current_id(&self) -> spirv::Word { - self.current_id - } - - fn start_fn<'b>( - &'b mut self, - header: &'b ast::MethodDeclaration<'input, &'input str>, - ) -> Result< - ( - FnStringIdResolver<'input, 'b>, - GlobalFnDeclResolver<'input, 'b>, - Rc>>, - ), - TranslateError, - > { - // In case a function decl was inserted earlier we want to use its id - let name_id = self.get_or_add_def(header.name()); - let mut fn_resolver = FnStringIdResolver { - current_id: &mut self.current_id, - global_variables: &self.variables, - global_type_check: &self.variables_type_check, - special_registers: &mut self.special_registers, - variables: vec![HashMap::new(); 1], - type_check: HashMap::new(), - }; - let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); - let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); - let name = match header.name { - ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), - ast::MethodName::Func(_) => ast::MethodName::Func(name_id), - }; - let fn_decl = ast::MethodDeclaration { - return_arguments, - name, - input_arguments, - shared_mem: None, - }; - let new_fn_decl = if !fn_decl.name.is_kernel() { - let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); - let new_fn_decl = resolver.func_decl.clone(); - self.fns.insert(name_id, resolver); - new_fn_decl - } else { - Rc::new(RefCell::new(fn_decl)) - }; - Ok(( - fn_resolver, - GlobalFnDeclResolver { fns: &self.fns }, - new_fn_decl, - )) - } -} - -pub struct GlobalFnDeclResolver<'input, 'a> { - fns: &'a HashMap>, -} - -impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { - fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> { - self.fns.get(&id).ok_or_else(error_unknown_symbol) - } -} - -struct FnStringIdResolver<'input, 'b> { - current_id: &'b mut spirv::Word, - global_variables: &'b HashMap, spirv::Word>, - global_type_check: &'b HashMap>, - special_registers: &'b mut SpecialRegistersMap, - variables: Vec, spirv::Word>>, - type_check: HashMap>, -} - -impl<'a, 'b> FnStringIdResolver<'a, 'b> { - fn finish(self) -> NumericIdResolver<'b> { - NumericIdResolver { - current_id: self.current_id, - global_type_check: self.global_type_check, - type_check: self.type_check, - special_registers: self.special_registers, - } - } - - fn start_block(&mut self) { - self.variables.push(HashMap::new()) - } - - fn end_block(&mut self) { - self.variables.pop(); - } - - fn get_id(&mut self, id: &str) -> Result { - for scope in self.variables.iter().rev() { - match scope.get(id) { - Some(id) => return Ok(*id), - None => continue, - } - } - match self.global_variables.get(id) { - Some(id) => Ok(*id), - None => { - let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?; - Ok(self.special_registers.get_or_add(self.current_id, sreg)) - } - } - } - - fn add_def( - &mut self, - id: &'a str, - typ: Option<(ast::Type, ast::StateSpace)>, - is_variable: bool, - ) -> spirv::Word { - let numeric_id = *self.current_id; - self.variables - .last_mut() - .unwrap() - .insert(Cow::Borrowed(id), numeric_id); - self.type_check.insert( - numeric_id, - typ.map(|(typ, space)| (typ, space, is_variable)), - ); - *self.current_id += 1; - numeric_id - } - - #[must_use] - fn add_defs( - &mut self, - base_id: &'a str, - count: u32, - typ: ast::Type, - state_space: ast::StateSpace, - is_variable: bool, - ) -> impl Iterator { - let numeric_id = *self.current_id; - for i in 0..count { - self.variables - .last_mut() - .unwrap() - .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check.insert( - numeric_id + i, - Some((typ.clone(), state_space, is_variable)), - ); - } - *self.current_id += count; - (0..count).into_iter().map(move |i| i + numeric_id) - } -} - -struct NumericIdResolver<'b> { - current_id: &'b mut spirv::Word, - global_type_check: &'b HashMap>, - type_check: HashMap>, - special_registers: &'b mut SpecialRegistersMap, -} - -impl<'b> NumericIdResolver<'b> { - fn finish(self) -> MutableNumericIdResolver<'b> { - MutableNumericIdResolver { base: self } - } - - fn get_typed( - &self, - id: spirv::Word, - ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { - match self.type_check.get(&id) { - Some(Some(x)) => Ok(x.clone()), - Some(None) => Err(TranslateError::UntypedSymbol), - None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), - None => match self.global_type_check.get(&id) { - Some(Some(result)) => Ok(result.clone()), - Some(None) | None => Err(TranslateError::UntypedSymbol), - }, - }, - } - } - - // This is for identifiers which will be emitted later as OpVariable - // They are candidates for insertion of LoadVar/StoreVar - fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word { - let new_id = *self.current_id; - self.type_check - .insert(new_id, Some((typ, state_space, true))); - *self.current_id += 1; - new_id - } - - fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word { - let new_id = *self.current_id; - self.type_check - .insert(new_id, typ.map(|(t, space)| (t, space, false))); - *self.current_id += 1; - new_id - } -} - -struct MutableNumericIdResolver<'b> { - base: NumericIdResolver<'b>, -} - -impl<'b> MutableNumericIdResolver<'b> { - fn unmut(self) -> NumericIdResolver<'b> { - self.base - } - - fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> { - self.base.get_typed(id).map(|(t, space, _)| (t, space)) - } - - fn register_intermediate( - &mut self, - typ: ast::Type, - state_space: ast::StateSpace, - ) -> spirv::Word { - self.base.register_intermediate(Some((typ, state_space))) - } -} - -struct FunctionPointerDetails { - dst: spirv::Word, - src: spirv::Word, -} - -impl, U: ArgParamsEx> Visitable - for FunctionPointerDetails -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::FunctionPointer(FunctionPointerDetails { - dst: visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::U64), - ast::StateSpace::Reg, - )), - )?, - src: visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - None, - )?, - })) - } -} - -enum Statement { - Label(u32), - Variable(ast::Variable), - Instruction(I), - // SPIR-V compatible replacement for PTX predicates - Conditional(BrachCondition), - Call(ResolvedCall

), - LoadVar(LoadVarDetails), - StoreVar(StoreVarDetails), - Conversion(ImplicitConversion), - Constant(ConstantDefinition), - RetValue(ast::RetData, spirv::Word), - PtrAccess(PtrAccess

), - RepackVector(RepackVectorDetails), - FunctionPointer(FunctionPointerDetails), -} - -impl ExpandedStatement { - fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement { - match self { - Statement::Label(id) => Statement::Label(f(id, false)), - Statement::Variable(mut var) => { - var.name = f(var.name, true); - Statement::Variable(var) - } - Statement::Instruction(inst) => inst - .visit(&mut |arg: ArgumentDescriptor<_>, - _: Option<(&ast::Type, ast::StateSpace)>| { - Ok(f(arg.op, arg.is_dst)) - }) - .unwrap(), - Statement::LoadVar(mut details) => { - details.arg.dst = f(details.arg.dst, true); - details.arg.src = f(details.arg.src, false); - Statement::LoadVar(details) - } - Statement::StoreVar(mut details) => { - details.arg.src1 = f(details.arg.src1, false); - details.arg.src2 = f(details.arg.src2, false); - Statement::StoreVar(details) - } - Statement::Call(mut call) => { - for (id, _, space) in call.return_arguments.iter_mut() { - let is_dst = match space { - ast::StateSpace::Reg => true, - ast::StateSpace::Param => false, - ast::StateSpace::Shared => false, - _ => todo!(), - }; - *id = f(*id, is_dst); - } - call.name = f(call.name, false); - for (id, _, _) in call.input_arguments.iter_mut() { - *id = f(*id, false); - } - Statement::Call(call) - } - Statement::Conditional(mut conditional) => { - conditional.predicate = f(conditional.predicate, false); - conditional.if_true = f(conditional.if_true, false); - conditional.if_false = f(conditional.if_false, false); - Statement::Conditional(conditional) - } - Statement::Conversion(mut conv) => { - conv.dst = f(conv.dst, true); - conv.src = f(conv.src, false); - Statement::Conversion(conv) - } - Statement::Constant(mut constant) => { - constant.dst = f(constant.dst, true); - Statement::Constant(constant) - } - Statement::RetValue(data, id) => { - let id = f(id, false); - Statement::RetValue(data, id) - } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src: constant_src, - }) => { - let dst = f(dst, true); - let ptr_src = f(ptr_src, false); - let constant_src = f(constant_src, false); - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src: constant_src, - }) - } - Statement::RepackVector(repack) => { - let packed = f(repack.packed, !repack.is_extract); - let unpacked = repack - .unpacked - .iter() - .map(|id| f(*id, repack.is_extract)) - .collect(); - Statement::RepackVector(RepackVectorDetails { - packed, - unpacked, - ..repack - }) - } - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { - Statement::FunctionPointer(FunctionPointerDetails { - dst: f(dst, true), - src: f(src, false), - }) - } - } - } -} - -struct LoadVarDetails { - arg: ast::Arg2, - typ: ast::Type, - state_space: ast::StateSpace, - // (index, vector_width) - // HACK ALERT - // For some reason IGC explodes when you try to load from builtin vectors - // using OpInBoundsAccessChain, the one true way to do it is to - // OpLoad+OpCompositeExtract - member_index: Option<(u8, Option)>, -} - -struct StoreVarDetails { - arg: ast::Arg2St, - typ: ast::Type, - member_index: Option, -} - -struct RepackVectorDetails { - is_extract: bool, - typ: ast::ScalarType, - packed: spirv::Word, - unpacked: Vec, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, -} - -impl RepackVectorDetails { - fn map< - From: ArgParamsEx, - To: ArgParamsEx, - V: ArgumentMapVisitor, - >( - self, - visitor: &mut V, - ) -> Result { - let scalar = visitor.id( - ArgumentDescriptor { - op: self.packed, - is_dst: !self.is_extract, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Vector(self.typ, self.unpacked.len() as u8), - ast::StateSpace::Reg, - )), - )?; - let scalar_type = self.typ; - let is_extract = self.is_extract; - let non_default_implicit_conversion = self.non_default_implicit_conversion; - let vector = self - .unpacked - .into_iter() - .map(|id| { - visitor.id( - ArgumentDescriptor { - op: id, - is_dst: is_extract, - is_memory_access: false, - non_default_implicit_conversion, - }, - Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), - ) - }) - .collect::>()?; - Ok(RepackVectorDetails { - is_extract, - typ: self.typ, - packed: scalar, - unpacked: vector, - non_default_implicit_conversion, - }) - } -} - -impl, U: ArgParamsEx> Visitable - for RepackVectorDetails -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) - } -} - -struct ResolvedCall { - pub uniform: bool, - pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>, - pub name: P::Id, - pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>, -} - -impl ResolvedCall { - fn cast>(self) -> ResolvedCall { - ResolvedCall { - uniform: self.uniform, - return_arguments: self.return_arguments, - name: self.name, - input_arguments: self.input_arguments, - } - } -} - -impl> ResolvedCall { - fn map, V: ArgumentMapVisitor>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let return_arguments = self - .return_arguments - .into_iter() - .map::, _>(|(id, typ, space)| { - let new_id = visitor.id( - ArgumentDescriptor { - op: id, - is_dst: space != ast::StateSpace::Param, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&typ, space)), - )?; - Ok((new_id, typ, space)) - }) - .collect::, _>>()?; - let func = visitor.id( - ArgumentDescriptor { - op: self.name, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - None, - )?; - let input_arguments = self - .input_arguments - .into_iter() - .map::, _>(|(id, typ, space)| { - let new_id = visitor.operand( - ArgumentDescriptor { - op: id, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &typ, - space, - )?; - Ok((new_id, typ, space)) - }) - .collect::, _>>()?; - Ok(ResolvedCall { - uniform: self.uniform, - return_arguments, - name: func, - input_arguments, - }) - } -} - -impl, U: ArgParamsEx> Visitable - for ResolvedCall -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::Call(self.map(visitor)?)) - } -} - -impl> PtrAccess

{ - fn map, V: ArgumentMapVisitor>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let new_dst = visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.underlying_type, self.state_space)), - )?; - let new_ptr_src = visitor.id( - ArgumentDescriptor { - op: self.ptr_src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.underlying_type, self.state_space)), - )?; - let new_constant_src = visitor.operand( - ArgumentDescriptor { - op: self.offset_src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::S64), - ast::StateSpace::Reg, - )?; - Ok(PtrAccess { - underlying_type: self.underlying_type, - state_space: self.state_space, - dst: new_dst, - ptr_src: new_ptr_src, - offset_src: new_constant_src, - }) - } -} - -impl, U: ArgParamsEx> Visitable - for PtrAccess -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::PtrAccess(self.map(visitor)?)) - } -} - -pub trait ArgParamsEx: ast::ArgParams + Sized {} - -impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {} - -enum NormalizedArgParams {} - -impl ast::ArgParams for NormalizedArgParams { - type Id = spirv::Word; - type Operand = ast::Operand; -} - -impl ArgParamsEx for NormalizedArgParams {} - -type NormalizedStatement = Statement< - ( - Option>, - ast::Instruction, - ), - NormalizedArgParams, ->; - -type UnconditionalStatement = Statement, NormalizedArgParams>; - -enum TypedArgParams {} - -impl ast::ArgParams for TypedArgParams { - type Id = spirv::Word; - type Operand = TypedOperand; -} - -impl ArgParamsEx for TypedArgParams {} - -#[derive(Copy, Clone)] -enum TypedOperand { - Reg(spirv::Word), - RegOffset(spirv::Word, i32), - Imm(ast::ImmediateValue), - VecMember(spirv::Word, u8), -} - -impl TypedOperand { - fn upcast(self) -> ast::Operand { - match self { - TypedOperand::Reg(reg) => ast::Operand::Reg(reg), - TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), - TypedOperand::Imm(x) => ast::Operand::Imm(x), - TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx), - } - } -} - -type TypedStatement = Statement, TypedArgParams>; - -enum ExpandedArgParams {} -type ExpandedStatement = Statement, ExpandedArgParams>; - -impl ast::ArgParams for ExpandedArgParams { - type Id = spirv::Word; - type Operand = spirv::Word; -} - -impl ArgParamsEx for ExpandedArgParams {} - -enum Directive<'input> { - Variable(ast::LinkingDirective, ast::Variable), - Method(Function<'input>), -} - -struct Function<'input> { - pub func_decl: Rc>>, - pub globals: Vec>, - pub body: Option>, - import_as: Option, - tuning: Vec, - linkage: ast::LinkingDirective, -} - -pub trait ArgumentMapVisitor { - fn id( - &mut self, - desc: ArgumentDescriptor, - typ: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result; - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result; -} - -impl ArgumentMapVisitor for T -where - T: FnMut( - ArgumentDescriptor, - Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - self(desc, Some((typ, state_space))) - } -} - -impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T -where - T: FnMut(&str) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor<&str>, - _: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc.op) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - Ok(match desc.op { - ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), - ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), - ast::Operand::Imm(imm) => ast::Operand::Imm(imm), - ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), - ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( - ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some((typ, state_space)))) - .collect::, _>>()?, - ), - }) - } -} - -pub struct ArgumentDescriptor { - op: Op, - is_dst: bool, - is_memory_access: bool, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, -} - -pub struct PtrAccess { - underlying_type: ast::Type, - state_space: ast::StateSpace, - dst: spirv::Word, - ptr_src: spirv::Word, - offset_src: P::Operand, -} - -impl ArgumentDescriptor { - fn new_op(&self, u: U) -> ArgumentDescriptor { - ArgumentDescriptor { - op: u, - is_dst: self.is_dst, - is_memory_access: self.is_memory_access, - non_default_implicit_conversion: self.non_default_implicit_conversion, - } - } -} - -impl ast::Instruction { - fn map>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - Ok(match self { - ast::Instruction::Abs(d, arg) => { - ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) - } - // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => return Err(error_unreachable()), - ast::Instruction::Ld(d, a) => { - let new_args = a.map(visitor, &d)?; - ast::Instruction::Ld(d, new_args) - } - ast::Instruction::Mov(d, a) => { - let mapped = a.map(visitor, &d)?; - ast::Instruction::Mov(d, mapped) - } - ast::Instruction::Mul(d, a) => { - let inst_type = d.get_type(); - let is_wide = d.is_wide(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?) - } - ast::Instruction::Add(d, a) => { - let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?) - } - ast::Instruction::Setp(d, a) => { - let inst_type = d.typ; - ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) - } - ast::Instruction::SetpBool(d, a) => { - let inst_type = d.typ; - ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) - } - ast::Instruction::Not(t, a) => { - ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?) - } - ast::Instruction::Cvt(d, a) => { - let (dst_t, src_t, int_to_int) = match &d { - ast::CvtDetails::FloatFromFloat(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::FloatFromInt(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::IntFromFloat(desc) => ((desc.dst, desc.src, false)), - ast::CvtDetails::IntFromInt(desc) => ((desc.dst, desc.src, true)), - }; - ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t, int_to_int)?) - } - ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?) - } - ast::Instruction::Shr(t, a) => { - ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) - } - ast::Instruction::St(d, a) => { - let new_args = a.map(visitor, &d)?; - ast::Instruction::St(d, new_args) - } - ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?), - ast::Instruction::Ret(d) => ast::Instruction::Ret(d), - ast::Instruction::Cvta(d, a) => { - let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?) - } - ast::Instruction::Mad(d, a) => { - let inst_type = d.get_type(); - let is_wide = d.is_wide(); - ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?) - } - ast::Instruction::Fma(d, a) => { - let inst_type = ast::Type::Scalar(d.typ); - ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?) - } - ast::Instruction::Or(t, a) => ast::Instruction::Or( - t, - a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, - ), - ast::Instruction::Sub(d, a) => { - let typ = d.get_type(); - ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Min(d, a) => { - let typ = d.get_type(); - ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Max(d, a) => { - let typ = d.get_type(); - ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?) - } - ast::Instruction::Rcp(d, a) => { - let typ = ast::Type::Scalar(if d.is_f64 { - ast::ScalarType::F64 - } else { - ast::ScalarType::F32 - }); - ast::Instruction::Rcp(d, a.map(visitor, &typ)?) - } - ast::Instruction::And(t, a) => ast::Instruction::And( - t, - a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, - ), - ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?), - ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?), - ast::Instruction::Atom(d, a) => { - ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?) - } - ast::Instruction::AtomCas(d, a) => { - ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?) - } - ast::Instruction::Div(d, a) => { - ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?) - } - ast::Instruction::Sqrt(d, a) => { - ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) - } - ast::Instruction::Rsqrt(d, a) => { - ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) - } - ast::Instruction::Neg(d, a) => { - ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?) - } - ast::Instruction::Sin { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Sin { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Cos { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Cos { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Lg2 { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Lg2 { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Ex2 { flush_to_zero, arg } => { - let typ = ast::Type::Scalar(ast::ScalarType::F32); - ast::Instruction::Ex2 { - flush_to_zero, - arg: arg.map(visitor, &typ)?, - } - } - ast::Instruction::Clz { typ, arg } => { - let dst_type = ast::Type::Scalar(ast::ScalarType::B32); - let src_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Clz { - typ, - arg: arg.map_different_types(visitor, &dst_type, &src_type)?, - } - } - ast::Instruction::Brev { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Brev { - typ, - arg: arg.map(visitor, &full_type)?, - } - } - ast::Instruction::Popc { typ, arg } => { - let dst_type = ast::Type::Scalar(ast::ScalarType::B32); - let src_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Popc { - typ, - arg: arg.map_different_types(visitor, &dst_type, &src_type)?, - } - } - ast::Instruction::Xor { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Xor { - typ, - arg: arg.map_non_shift(visitor, &full_type, false)?, - } - } - ast::Instruction::Bfe { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Bfe { - typ, - arg: arg.map_bfe(visitor, &full_type)?, - } - } - ast::Instruction::Bfi { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Bfi { - typ, - arg: arg.map_bfi(visitor, &full_type)?, - } - } - ast::Instruction::Rem { typ, arg } => { - let full_type = ast::Type::Scalar(typ.into()); - ast::Instruction::Rem { - typ, - arg: arg.map_non_shift(visitor, &full_type, false)?, - } - } - ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt { - control, - arg: arg.map_prmt(visitor)?, - }, - ast::Instruction::Activemask { arg } => ast::Instruction::Activemask { - arg: arg.map( - visitor, - true, - Some(( - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )), - )?, - }, - ast::Instruction::Membar { level } => ast::Instruction::Membar { level }, - }) - } -} - -impl Visitable for ast::Instruction { - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, U>, TranslateError> { - Ok(Statement::Instruction(self.map(visitor)?)) - } -} - -impl ImplicitConversion { - fn map< - T: ArgParamsEx, - U: ArgParamsEx, - V: ArgumentMapVisitor, - >( - self, - visitor: &mut V, - ) -> Result, U>, TranslateError> { - let new_dst = visitor.id( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.to_type, self.to_space)), - )?; - let new_src = visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some((&self.from_type, self.from_space)), - )?; - Ok(Statement::Conversion({ - ImplicitConversion { - src: new_src, - dst: new_dst, - ..self - } - })) - } -} - -impl, To: ArgParamsEx> Visitable - for ImplicitConversion -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError> { - Ok(self.map(visitor)?) - } -} - -impl ArgumentMapVisitor for T -where - T: FnMut( - ArgumentDescriptor, - Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, -{ - fn id( - &mut self, - desc: ArgumentDescriptor, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result { - self(desc, t) - } - - fn operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - state_space: ast::StateSpace, - ) -> Result { - Ok(match desc.op { - TypedOperand::Reg(id) => { - TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?) - } - TypedOperand::Imm(imm) => TypedOperand::Imm(imm), - TypedOperand::RegOffset(id, imm) => { - TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm) - } - TypedOperand::VecMember(reg, index) => { - let scalar_type = match typ { - ast::Type::Scalar(scalar_t) => *scalar_t, - _ => return Err(error_unreachable()), - }; - let vec_type = ast::Type::Vector(scalar_type, index + 1); - TypedOperand::VecMember( - self(desc.new_op(reg), Some((&vec_type, state_space)))?, - index, - ) - } - }) - } -} - -impl ast::Type { - fn widen(self) -> Result { - match self { - ast::Type::Scalar(scalar) => { - let kind = scalar.kind(); - let width = scalar.size_of(); - if (kind != ast::ScalarKind::Signed - && kind != ast::ScalarKind::Unsigned - && kind != ast::ScalarKind::Bit) - || (width == 8) - { - return Err(TranslateError::MismatchedType); - } - Ok(ast::Type::Scalar(ast::ScalarType::from_parts( - width * 2, - kind, - ))) - } - _ => Err(error_unreachable()), - } - } - - fn to_parts(&self) -> TypeParts { - match self { - ast::Type::Scalar(scalar) => TypeParts { - kind: TypeKind::Scalar, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - ast::Type::Vector(scalar, components) => TypeParts { - kind: TypeKind::Vector, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*components as u32], - }, - ast::Type::Array(scalar, components) => TypeParts { - kind: TypeKind::Array, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - }, - ast::Type::Pointer(scalar, space) => TypeParts { - kind: TypeKind::Pointer, - state_space: *space, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - } - } - - fn from_parts(t: TypeParts) -> Self { - match t.kind { - TypeKind::Scalar => { - ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)) - } - TypeKind::Vector => ast::Type::Vector( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components[0] as u8, - ), - TypeKind::Array => ast::Type::Array( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components, - ), - TypeKind::Pointer => ast::Type::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.state_space, - ), - } - } - - pub fn size_of(&self) -> usize { - match self { - ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(typ, len) => len - .iter() - .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(..) => mem::size_of::(), - } - } -} - -#[derive(Eq, PartialEq, Clone)] -struct TypeParts { - kind: TypeKind, - scalar_kind: ast::ScalarKind, - width: u8, - state_space: ast::StateSpace, - components: Vec, -} - -#[derive(Eq, PartialEq, Copy, Clone)] -enum TypeKind { - Scalar, - Vector, - Array, - Pointer, -} - -impl ast::Instruction { - fn jump_target(&self) -> Option { - match self { - ast::Instruction::Bra(_, a) => Some(a.src), - _ => None, - } - } - - // .wide instructions don't support ftz, so it's enough to just look at the - // type declared by the instruction - fn flush_to_zero(&self) -> Option<(bool, u8)> { - match self { - ast::Instruction::Ld(_, _) => None, - ast::Instruction::St(_, _) => None, - ast::Instruction::Mov(_, _) => None, - ast::Instruction::Not(_, _) => None, - ast::Instruction::Bra(_, _) => None, - ast::Instruction::Shl(_, _) => None, - ast::Instruction::Shr(_, _) => None, - ast::Instruction::Ret(_) => None, - ast::Instruction::Call(_) => None, - ast::Instruction::Or(_, _) => None, - ast::Instruction::And(_, _) => None, - ast::Instruction::Cvta(_, _) => None, - ast::Instruction::Selp(_, _) => None, - ast::Instruction::Bar(_, _) => None, - ast::Instruction::Atom(_, _) => None, - ast::Instruction::AtomCas(_, _) => None, - ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, - ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, - ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, - ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None, - ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None, - ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None, - ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None, - ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None, - ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None, - ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None, - ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None, - ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None, - ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None, - ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None, - ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None, - ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None, - ast::Instruction::Clz { .. } => None, - ast::Instruction::Brev { .. } => None, - ast::Instruction::Popc { .. } => None, - ast::Instruction::Xor { .. } => None, - ast::Instruction::Bfe { .. } => None, - ast::Instruction::Bfi { .. } => None, - ast::Instruction::Rem { .. } => None, - ast::Instruction::Prmt { .. } => None, - ast::Instruction::Activemask { .. } => None, - ast::Instruction::Membar { .. } => None, - ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) - | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) - | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) - | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), - ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())), - ast::Instruction::Setp(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::SetpBool(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Abs(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _) - | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), - ast::Instruction::Rcp(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })), - // Modifier .ftz can only be specified when either .dtype or .atype - // is .f32 and applies only to single precision (.f32) inputs and results. - ast::Instruction::Cvt( - ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }), - _, - ) - | ast::Instruction::Cvt( - ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }), - _, - ) => flush_to_zero.map(|ftz| (ftz, 4)), - ast::Instruction::Div(ast::DivDetails::Float(details), _) => details - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), - ast::Instruction::Sqrt(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), - ast::Instruction::Rsqrt(details, _) => Some(( - details.flush_to_zero, - ast::ScalarType::from(details.typ).size_of(), - )), - ast::Instruction::Neg(details, _) => details - .flush_to_zero - .map(|ftz| (ftz, details.typ.size_of())), - ast::Instruction::Sin { flush_to_zero, .. } - | ast::Instruction::Cos { flush_to_zero, .. } - | ast::Instruction::Lg2 { flush_to_zero, .. } - | ast::Instruction::Ex2 { flush_to_zero, .. } => { - Some((*flush_to_zero, mem::size_of::() as u8)) - } - } - } -} - -type Arg2 = ast::Arg2; -type Arg2St = ast::Arg2St; - -struct ConstantDefinition { - pub dst: spirv::Word, - pub typ: ast::ScalarType, - pub value: ast::ImmediateValue, -} - -struct BrachCondition { - predicate: spirv::Word, - if_true: spirv::Word, - if_false: spirv::Word, -} - -impl, To: ArgParamsEx> Visitable - for BrachCondition -{ - fn visit( - self, - visitor: &mut impl ArgumentMapVisitor, - ) -> Result, To>, TranslateError> { - let predicate = visitor.id( - ArgumentDescriptor { - op: self.predicate, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let if_true = self.if_true; - let if_false = self.if_false; - Ok(Statement::Conditional(BrachCondition { - predicate, - if_true, - if_false, - })) - } -} - -#[derive(Clone)] -struct ImplicitConversion { - src: spirv::Word, - dst: spirv::Word, - from_type: ast::Type, - to_type: ast::Type, - from_space: ast::StateSpace, - to_space: ast::StateSpace, - kind: ConversionKind, -} - -#[derive(PartialEq, Clone)] -enum ConversionKind { - Default, - // zero-extend/chop/bitcast depending on types - SignExtend, - BitToPtr, - PtrToPtr, - AddressOf, -} - -impl ast::PredAt { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - let new_label = f(self.label)?; - Ok(ast::PredAt { - not: self.not, - label: new_label, - }) - } -} - -impl<'a> ast::Instruction> { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - match self { - ast::Instruction::Call(call) => { - let call_inst = ast::CallInst { - uniform: call.uniform, - ret_params: call - .ret_params - .into_iter() - .map(|p| f(p)) - .collect::>()?, - func: f(call.func)?, - param_list: call - .param_list - .into_iter() - .map(|p| p.map_variable(f)) - .collect::>()?, - }; - Ok(ast::Instruction::Call(call_inst)) - } - i => i.map(f), - } - } -} - -impl ast::Arg1 { - fn map>( - self, - visitor: &mut V, - is_dst: bool, - t: Option<(&ast::Type, ast::StateSpace)>, - ) -> Result, TranslateError> { - let new_src = visitor.id( - ArgumentDescriptor { - op: self.src, - is_dst, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - )?; - Ok(ast::Arg1 { src: new_src }) - } -} - -impl ast::Arg1Bar { - fn map>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let new_src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg1Bar { src: new_src }) - } -} - -impl ast::Arg2 { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let new_dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let new_src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { - dst: new_dst, - src: new_src, - }) - } - - fn map_cvt>( - self, - visitor: &mut V, - dst_t: ast::ScalarType, - src_t: ast::ScalarType, - is_int_to_int: bool, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: if is_int_to_int { - Some(should_convert_relaxed_dst_wrapper) - } else { - None - }, - }, - &ast::Type::Scalar(dst_t), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: if is_int_to_int { - Some(should_convert_relaxed_src_wrapper) - } else { - None - }, - }, - &ast::Type::Scalar(src_t), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { dst, src }) - } - - fn map_different_types>( - self, - visitor: &mut V, - dst_t: &ast::Type, - src_t: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - dst_t, - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - src_t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2 { dst, src }) - } -} - -impl ast::Arg2Ld { - fn map>( - self, - visitor: &mut V, - details: &ast::LdDetails, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper), - }, - &ast::Type::from(details.typ.clone()), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &details.typ, - details.state_space, - )?; - Ok(ast::Arg2Ld { dst, src }) - } -} - -impl ast::Arg2St { - fn map>( - self, - visitor: &mut V, - details: &ast::StData, - ) -> Result, TranslateError> { - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &details.typ, - details.state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper), - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2St { src1, src2 }) - } -} - -impl ast::Arg2Mov { - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: Some(implicit_conversion_mov), - }, - &details.typ.clone().into(), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg2Mov { dst, src }) - } -} - -impl ast::Arg3 { - fn map_non_shift>( - self, - visitor: &mut V, - typ: &ast::Type, - is_wide: bool, - ) -> Result, TranslateError> { - let wide_type = if is_wide { - Some(typ.clone().widen()?) - } else { - None - }; - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - wide_type.as_ref().unwrap_or(typ), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_shift>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_atom>( - self, - visitor: &mut V, - t: ast::ScalarType, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - let scalar_type = ast::ScalarType::from(t); - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } - - fn map_prmt>( - self, - visitor: &mut V, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::B32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg3 { dst, src1, src2 }) - } -} - -impl ast::Arg4 { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - is_wide: bool, - ) -> Result, TranslateError> { - let wide_type = if is_wide { - Some(t.clone().widen()?) - } else { - None - }; - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - wide_type.as_ref().unwrap_or(t), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_selp>( - self, - visitor: &mut V, - t: ast::ScalarType, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(t.into()), - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_atom>( - self, - visitor: &mut V, - t: ast::ScalarType, - state_space: ast::StateSpace, - ) -> Result, TranslateError> { - let scalar_type = ast::ScalarType::from(t); - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: true, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - state_space, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(scalar_type), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } - - fn map_bfe>( - self, - visitor: &mut V, - typ: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - typ, - ast::StateSpace::Reg, - )?; - let u32_type = ast::Type::Scalar(ast::ScalarType::U32); - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &u32_type, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &u32_type, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4 { - dst, - src1, - src2, - src3, - }) - } -} - -impl ast::Arg4Setp { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst1 = visitor.id( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let dst2 = self - .dst2 - .map(|dst2| { - visitor.id( - ArgumentDescriptor { - op: dst2, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - ) - }) - .transpose()?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - Ok(ast::Arg4Setp { - dst1, - dst2, - src1, - src2, - }) - } -} - -impl ast::Arg5 { - fn map_bfi>( - self, - visitor: &mut V, - base_type: &ast::Type, - ) -> Result, TranslateError> { - let dst = visitor.operand( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - base_type, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - let src4 = visitor.operand( - ArgumentDescriptor { - op: self.src4, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg5 { - dst, - src1, - src2, - src3, - src4, - }) - } -} - -impl ast::Arg5Setp { - fn map>( - self, - visitor: &mut V, - t: &ast::Type, - ) -> Result, TranslateError> { - let dst1 = visitor.id( - ArgumentDescriptor { - op: self.dst1, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - )?; - let dst2 = self - .dst2 - .map(|dst2| { - visitor.id( - ArgumentDescriptor { - op: dst2, - is_dst: true, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - Some(( - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )), - ) - }) - .transpose()?; - let src1 = visitor.operand( - ArgumentDescriptor { - op: self.src1, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src2 = visitor.operand( - ArgumentDescriptor { - op: self.src2, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - t, - ast::StateSpace::Reg, - )?; - let src3 = visitor.operand( - ArgumentDescriptor { - op: self.src3, - is_dst: false, - is_memory_access: false, - non_default_implicit_conversion: None, - }, - &ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - )?; - Ok(ast::Arg5Setp { - dst1, - dst2, - src1, - src2, - src3, - }) - } -} - -impl ast::Operand { - fn map_variable Result>( - self, - f: &mut F, - ) -> Result, TranslateError> { - Ok(match self { - ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), - ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), - ast::Operand::Imm(x) => ast::Operand::Imm(x), - ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), - ast::Operand::VecPack(vec) => { - ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) - } - }) - } -} - -impl ast::Operand { - fn unwrap_reg(&self) -> Result { - match self { - ast::Operand::Reg(reg) => Ok(*reg), - _ => Err(error_unreachable()), - } - } -} - -impl ast::ScalarType { - fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { - match kind { - ast::ScalarKind::Float => match width { - 2 => ast::ScalarType::F16, - 4 => ast::ScalarType::F32, - 8 => ast::ScalarType::F64, - _ => unreachable!(), - }, - ast::ScalarKind::Bit => match width { - 1 => ast::ScalarType::B8, - 2 => ast::ScalarType::B16, - 4 => ast::ScalarType::B32, - 8 => ast::ScalarType::B64, - _ => unreachable!(), - }, - ast::ScalarKind::Signed => match width { - 1 => ast::ScalarType::S8, - 2 => ast::ScalarType::S16, - 4 => ast::ScalarType::S32, - 8 => ast::ScalarType::S64, - _ => unreachable!(), - }, - ast::ScalarKind::Unsigned => match width { - 1 => ast::ScalarType::U8, - 2 => ast::ScalarType::U16, - 4 => ast::ScalarType::U32, - 8 => ast::ScalarType::U64, - _ => unreachable!(), - }, - ast::ScalarKind::Float2 => match width { - 4 => ast::ScalarType::F16x2, - _ => unreachable!(), - }, - ast::ScalarKind::Pred => ast::ScalarType::Pred, - } - } -} - -impl ast::ArithDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::ArithDetails::Unsigned(t) => (*t).into(), - ast::ArithDetails::Signed(d) => d.typ.into(), - ast::ArithDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::MulDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::MulDetails::Unsigned(d) => d.typ.into(), - ast::MulDetails::Signed(d) => d.typ.into(), - ast::MulDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::MinMaxDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::MinMaxDetails::Signed(t) => (*t).into(), - ast::MinMaxDetails::Unsigned(t) => (*t).into(), - ast::MinMaxDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::DivDetails { - fn get_type(&self) -> ast::Type { - ast::Type::Scalar(match self { - ast::DivDetails::Unsigned(t) => (*t).into(), - ast::DivDetails::Signed(t) => (*t).into(), - ast::DivDetails::Float(d) => d.typ.into(), - }) - } -} - -impl ast::AtomInnerDetails { - fn get_type(&self) -> ast::ScalarType { - match self { - ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(), - ast::AtomInnerDetails::Float { typ, .. } => (*typ).into(), - } - } -} - -impl ast::StateSpace { - fn to_spirv(self) -> spirv::StorageClass { - match self { - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Local => spirv::StorageClass::Function, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Param => spirv::StorageClass::Function, - ast::StateSpace::Reg => spirv::StorageClass::Function, - ast::StateSpace::Sreg => spirv::StorageClass::Input, - } - } - - fn is_compatible(self, other: ast::StateSpace) -> bool { - self == other - || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg - || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg - } - - fn coerces_to_generic(self) -> bool { - match self { - ast::StateSpace::Global - | ast::StateSpace::Const - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Reg - | ast::StateSpace::Param - | ast::StateSpace::Generic - | ast::StateSpace::Sreg => false, - } - } - - fn is_addressable(self) -> bool { - match self { - ast::StateSpace::Const - | ast::StateSpace::Generic - | ast::StateSpace::Global - | ast::StateSpace::Local - | ast::StateSpace::Shared => true, - ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, - } - } -} - -impl ast::Operand { - fn underlying_register(&self) -> Option<&T> { - match self { - ast::Operand::Reg(r) - | ast::Operand::RegOffset(r, _) - | ast::Operand::VecMember(r, _) => Some(r), - ast::Operand::Imm(_) | ast::Operand::VecPack(..) => None, - } - } -} - -impl ast::MulDetails { - fn is_wide(&self) -> bool { - match self { - ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide, - ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide, - ast::MulDetails::Float(_) => false, - } - } -} - -impl ast::MemScope { - fn to_spirv(self) -> spirv::Scope { - match self { - ast::MemScope::Cta => spirv::Scope::Workgroup, - ast::MemScope::Gpu => spirv::Scope::Device, - ast::MemScope::Sys => spirv::Scope::CrossDevice, - } - } -} - -impl ast::AtomSemantics { - fn to_spirv(self) -> spirv::MemorySemantics { - match self { - ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, - ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, - ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, - ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE, - } - } -} - -fn default_implicit_conversion( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if !instruction_space.is_compatible(operand_space) { - default_implicit_conversion_space( - (operand_space, operand_type), - (instruction_space, instruction_type), - ) - } else if instruction_type != operand_type { - default_implicit_conversion_type(instruction_space, operand_type, instruction_type) - } else { - Ok(None) - } -} - -// Space is different -fn default_implicit_conversion_space( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic()) - || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic()) - { - Ok(Some(ConversionKind::PtrToPtr)) - } else if operand_space.is_compatible(ast::StateSpace::Reg) { - match operand_type { - ast::Type::Pointer(operand_ptr_type, operand_ptr_space) - if *operand_ptr_space == instruction_space => - { - if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { - Ok(Some(ConversionKind::PtrToPtr)) - } else { - Ok(None) - } - } - // TODO: 32 bit - ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { - ast::StateSpace::Global - | ast::StateSpace::Generic - | ast::StateSpace::Const - | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), - _ => Err(TranslateError::MismatchedType), - }, - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { - ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { - Ok(Some(ConversionKind::BitToPtr)) - } - _ => Err(TranslateError::MismatchedType), - }, - _ => Err(TranslateError::MismatchedType), - } - } else if instruction_space.is_compatible(ast::StateSpace::Reg) { - match instruction_type { - ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) - if operand_space == *instruction_ptr_space => - { - if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { - Ok(Some(ConversionKind::PtrToPtr)) - } else { - Ok(None) - } - } - _ => Err(TranslateError::MismatchedType), - } - } else { - Err(TranslateError::MismatchedType) - } -} - -// Space is same, but type is different -fn default_implicit_conversion_type( - space: ast::StateSpace, - operand_type: &ast::Type, - instruction_type: &ast::Type, -) -> Result, TranslateError> { - if space.is_compatible(ast::StateSpace::Reg) { - if should_bitcast(instruction_type, operand_type) { - Ok(Some(ConversionKind::Default)) - } else { - Err(TranslateError::MismatchedType) - } - } else { - Ok(Some(ConversionKind::PtrToPtr)) - } -} - -fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { - match (instr, operand) { - (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { - if inst.size_of() != operand.size_of() { - return false; - } - match inst.kind() { - ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, - ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, - ast::ScalarKind::Signed => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Unsigned - } - ast::ScalarKind::Unsigned => { - operand.kind() == ast::ScalarKind::Bit - || operand.kind() == ast::ScalarKind::Signed - } - ast::ScalarKind::Float2 => false, - ast::ScalarKind::Pred => false, - } - } - (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) - | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { - should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) - } - _ => false, - } -} - -fn implicit_conversion_mov( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - // instruction_space is always reg - if operand_space.is_compatible(ast::StateSpace::Reg) { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand_type, instruction_type) - { - if scalar.kind() == ast::ScalarKind::Bit - && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) - { - return Ok(Some(ConversionKind::Default)); - } - } - // TODO: verify .params addressability: - // * kernel arg - // * func arg - // * variable - } else if operand_space.is_addressable() { - return Ok(Some(ConversionKind::AddressOf)); - } - default_implicit_conversion( - (operand_space, operand_type), - (instruction_space, instruction_type), - ) -} - -fn should_convert_relaxed_src_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if !operand_space.is_compatible(instruction_space) { - return Err(TranslateError::MismatchedType); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_src(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands -fn should_convert_relaxed_src( - src_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if src_type == instr_type { - return None; - } - match (src_type, instr_type) { - (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= src_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= src_type.size_of() - && src_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float2 => todo!(), - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { - should_convert_relaxed_src( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} - -fn should_convert_relaxed_dst_wrapper( - (operand_space, operand_type): (ast::StateSpace, &ast::Type), - (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), -) -> Result, TranslateError> { - if !operand_space.is_compatible(instruction_space) { - return Err(TranslateError::MismatchedType); - } - if operand_type == instruction_type { - return Ok(None); - } - match should_convert_relaxed_dst(operand_type, instruction_type) { - conv @ Some(_) => Ok(conv), - None => Err(TranslateError::MismatchedType), - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands -fn should_convert_relaxed_dst( - dst_type: &ast::Type, - instr_type: &ast::Type, -) -> Option { - if dst_type == instr_type { - return None; - } - match (dst_type, instr_type) { - (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { - ast::ScalarKind::Bit => { - if instr_type.size_of() <= dst_type.size_of() { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Signed => { - if dst_type.kind() != ast::ScalarKind::Float { - if instr_type.size_of() == dst_type.size_of() { - Some(ConversionKind::Default) - } else if instr_type.size_of() < dst_type.size_of() { - Some(ConversionKind::SignExtend) - } else { - None - } - } else { - None - } - } - ast::ScalarKind::Unsigned => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() != ast::ScalarKind::Float - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float => { - if instr_type.size_of() <= dst_type.size_of() - && dst_type.kind() == ast::ScalarKind::Bit - { - Some(ConversionKind::Default) - } else { - None - } - } - ast::ScalarKind::Float2 => todo!(), - ast::ScalarKind::Pred => None, - }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { - should_convert_relaxed_dst( - &ast::Type::Scalar(*dst_type), - &ast::Type::Scalar(*instr_type), - ) - } - _ => None, - } -} - -impl<'a> ast::MethodDeclaration<'a, &'a str> { - fn name(&self) -> &'a str { - match self.name { - ast::MethodName::Kernel(name) => name, - ast::MethodName::Func(name) => name, - } - } -} - -impl<'a> ast::MethodDeclaration<'a, spirv::Word> { - fn effective_input_arguments(&self) -> impl Iterator + '_ { - let is_kernel = self.name.is_kernel(); - self.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) - } -} - -impl<'input, ID> ast::MethodName<'input, ID> { - fn is_kernel(&self) -> bool { - match self { - ast::MethodName::Kernel(..) => true, - ast::MethodName::Func(..) => false, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ast; - - static SCALAR_TYPES: [ast::ScalarType; 15] = [ - ast::ScalarType::B8, - ast::ScalarType::B16, - ast::ScalarType::B32, - ast::ScalarType::B64, - ast::ScalarType::S8, - ast::ScalarType::S16, - ast::ScalarType::S32, - ast::ScalarType::S64, - ast::ScalarType::U8, - ast::ScalarType::U16, - ast::ScalarType::U32, - ast::ScalarType::U64, - ast::ScalarType::F16, - ast::ScalarType::F32, - ast::ScalarType::F64, - ]; - - static RELAXED_SRC_CONVERSION_TABLE: &'static str = - "b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop - b16 inv - chop chop inv - chop chop inv - chop chop - chop chop - b32 inv inv - chop inv inv - chop inv inv - chop inv - chop - b64 inv inv inv - inv inv inv - inv inv inv - inv inv - - s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv - s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv - s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv - s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv - u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv - u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv - u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv - u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv - f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv - f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv - f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -"; - - static RELAXED_DST_CONVERSION_TABLE: &'static str = - "b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext - b16 inv - zext zext inv - zext zext inv - zext zext - zext zext - b32 inv inv - zext inv inv - zext inv inv - zext inv - zext - b64 inv inv inv - inv inv inv - inv inv inv - inv inv - - s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv - s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv - s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv - s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv - u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv - u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv - u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv - u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv - f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv - f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv - f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -"; - - fn table_entry_to_conversion(entry: &'static str) -> Option { - match entry { - "-" => Some(ConversionKind::Default), - "inv" => None, - "zext" => Some(ConversionKind::Default), - "chop" => Some(ConversionKind::Default), - "sext" => Some(ConversionKind::SignExtend), - _ => unreachable!(), - } - } - - fn parse_conversion_table(table: &'static str) -> Vec>> { - table - .lines() - .map(|line| { - line.split_ascii_whitespace() - .skip(1) - .map(table_entry_to_conversion) - .collect::>() - }) - .collect::>() - } - - fn assert_conversion_table Option>( - table: &'static str, - f: F, - ) { - let conv_table = parse_conversion_table(table); - for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() { - for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() { - let conversion = f( - &ast::Type::Scalar(*op_type), - &ast::Type::Scalar(*instr_type), - ); - if instr_idx == op_idx { - assert!(conversion == None); - } else { - assert!(conversion == conv_table[instr_idx][op_idx]); - } - } - } - } - - #[test] - fn should_convert_relaxed_src_all_combinations() { - assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src); - } - - #[test] - fn should_convert_relaxed_dst_all_combinations() { - assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst); - } -} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 65c624e..f0d3fbe 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; -use std::{cmp::Ordering, num::NonZeroU8}; +use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; pub enum Statement { Label(P::Ident), @@ -806,6 +806,32 @@ impl Type { None => Self::maybe_vector_parsed(prefix, scalar), } } + + pub fn layout(&self) -> Layout { + match self { + Type::Scalar(type_) => type_.layout(), + Type::Vector(elements, scalar_type) => { + let scalar_layout = scalar_type.layout(); + unsafe { + Layout::from_size_align_unchecked( + scalar_layout.size() * *elements as usize, + scalar_layout.align() * *elements as usize, + ) + } + } + Type::Array(non_zero, scalar, vec) => { + let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout(); + let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0); + unsafe { + Layout::from_size_align_unchecked( + element_layout.size() * (len as usize), + element_layout.align(), + ) + } + } + Type::Pointer(..) => Layout::new::(), + } + } } impl ScalarType { @@ -831,6 +857,31 @@ impl ScalarType { } } + pub fn layout(self) -> Layout { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::(), + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => Layout::new::(), + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => Layout::new::(), + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => { + Layout::new::() + } + ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) }, + // Close enough + ScalarType::Pred => Layout::new::(), + } + } + pub fn kind(self) -> ScalarKind { match self { ScalarType::U8 => ScalarKind::Unsigned, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index fee11aa..b49503b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1349,10 +1349,10 @@ impl std::error::Error for TokenError {} // * After parsing, each instruction needs to do some early validation and generate a specific, // strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but // there can be multiple different code emitter backends -// * Most importantly, instruction modifiers can come in aby order, so e.g. both +// * Most importantly, instruction modifiers can come in aby order, so e.g. both // `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes // classic parsing generators fail: if we tried to generate parsing rules that cover every possible -// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang // will always emit modifiers in the correct order, but people who write inline assembly usually // get it wrong (even first party developers) // @@ -1398,7 +1398,7 @@ impl std::error::Error for TokenError {} // * List of rules. They are associated with the preceding patterns (until different opcode or // different rules). Rules are used to resolve modifiers. There are two types of rules: // * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we -// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, // FoobarEnum::DotC appropriately // * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurences of `.a` will // emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors @@ -3233,36 +3233,42 @@ mod tests { #[test] fn sm_11() { let tokens = Token::lexer(".target sm_11") - .collect::, ()>>() + .collect::, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); + assert_eq!(errors.len(), 0); } #[test] fn sm_90a() { let tokens = Token::lexer(".target sm_90a") - .collect::, ()>>() + .collect::, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + assert_eq!(errors.len(), 0); } #[test] fn sm_90ab() { let tokens = Token::lexer(".target sm_90ab") - .collect::, ()>>() + .collect::, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert!(target.parse(stream).is_err()); + assert_eq!(errors.len(), 0); } } -- cgit v1.2.3