aboutsummaryrefslogtreecommitdiffhomepage
path: root/notcuda_inject/src/win.rs
blob: ec57ffb57499ac7cd2faa04b06ef8ac66a838740 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#![allow(non_snake_case)]

use std::error;
use std::fmt;
use std::ptr;

mod c {
    //! Minimal hand-written bindings to the Win32 APIs used by this module.

    use std::ffi::c_void;
    use std::os::raw::c_ulong;

    // Scalar and character types.
    pub type DWORD = c_ulong;
    pub type WCHAR = u16;

    // Pointer and handle types.
    pub type LPVOID = *mut c_void;
    pub type LPCWSTR = *const WCHAR;
    pub type LPWSTR = *mut WCHAR;
    pub type HANDLE = LPVOID;
    pub type HINSTANCE = HANDLE;
    pub type HMODULE = HINSTANCE;

    /// Bit that marks an error code as an NTSTATUS value rather than a Win32 error.
    pub const FACILITY_NT_BIT: DWORD = 0x1000_0000;

    // Flags accepted by `FormatMessageW`.
    pub const FORMAT_MESSAGE_IGNORE_INSERTS: DWORD = 0x0000_0200;
    pub const FORMAT_MESSAGE_FROM_HMODULE: DWORD = 0x0000_0800;
    pub const FORMAT_MESSAGE_FROM_SYSTEM: DWORD = 0x0000_1000;

    extern "system" {
        pub fn GetLastError() -> DWORD;
        pub fn GetModuleHandleW(lpModuleName: LPCWSTR) -> HMODULE;
        pub fn FormatMessageW(
            dwFlags: DWORD,
            lpSource: LPVOID,
            dwMessageId: DWORD,
            dwLanguageId: DWORD,
            lpBuffer: LPWSTR,
            nSize: DWORD,
            Arguments: *const c_void,
        ) -> DWORD;
    }
}

/// Expands to a string literal naming the final identifier of a
/// comma-separated list, e.g. `last_ident!(c, GetLastError)` is `"GetLastError"`.
macro_rules! last_ident {
    // Base case: exactly one identifier is left, so stringify it.
    ($last:ident) => {
        stringify!($last)
    };
    // Recursive case: discard the leading identifier and look at the rest.
    ($head:ident, $($tail:ident),+) => {
        last_ident!($($tail),+)
    };
}

/// Calls a raw OS function inside `unsafe` and, when `$success` rejects its
/// return value, returns early from the enclosing function with an `OsError`.
///
/// * `$($path)::+ (args...)` - the function to call and its arguments
///   (zero or more).
/// * `$success` - a callable taking the function's return value and returning
///   `bool`; `false` means the call failed.
///
/// On failure the last-error code is read via `errno()`, described via
/// `error_string`, and the error is propagated with `?`. The enclosing
/// function must therefore return a `Result` whose error type implements
/// `From<OsError>`. The reported function name is the final path segment.
macro_rules! os_call {
    ($($path:ident)::+ ($($args:expr),*), $success:expr) => {
        // Re-emit the arguments with the same `*` repetition they were matched
        // with, so zero-argument calls expand consistently.
        let result = unsafe { $($path)::+ ($($args),*) };
        if !($success)(result) {
            // Capture the error code before any other call can overwrite it.
            let err_code = $crate::win::errno();
            Err($crate::win::OsError {
                function: last_ident!($($path),+),
                error_code: err_code as u32,
                message: $crate::win::error_string(err_code),
            })?;
        }
    };
}

/// Error describing a failed OS call (constructed by `os_call!`).
#[derive(Debug)]
pub struct OsError {
    /// Name of the OS function that failed.
    pub function: &'static str,
    /// Raw Windows error code (`os_call!` fills this from `GetLastError`).
    pub error_code: u32,
    /// Human-readable description of `error_code`.
    pub message: String,
}

impl fmt::Display for OsError {
    /// Formats as `<function> failed with OS error <code> (<hex code>): <message>`.
    ///
    /// The hex form is included because Windows NTSTATUS/HRESULT codes are
    /// conventionally written in hexadecimal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{} failed with OS error {} ({:#010x}): {}",
            self.function, self.error_code, self.error_code, self.message
        )
    }
}

// `source` keeps the trait default (`None`): an OS error wraps no other error.
impl error::Error for OsError {}

pub fn errno() -> i32 {
    unsafe { c::GetLastError() as i32 }
}

/// Gets a detailed string description for the given Windows error number.
///
/// If `errnum` has `FACILITY_NT_BIT` set, it is treated as an NTSTATUS value.
/// When `NTDLL.DLL` is loaded, the bit is stripped and the message is looked up
/// in that module's message table as well as the system table. Otherwise only
/// the system message table is consulted. Trailing whitespace (the CRLF that
/// `FormatMessageW` appends) is removed.
///
/// This never fails. If `FormatMessageW` fails or yields invalid UTF-16, a
/// fallback message containing the caller's original `errnum` is returned.
pub fn error_string(errnum: i32) -> String {
    // MAKELANGID(LANG_NEUTRAL, SUBLANG_SYS_DEFAULT), i.e. LANG_SYSTEM_DEFAULT.
    const LANG_ID: c::DWORD = 0x0800;
    // "NTDLL.DLL" as a NUL-terminated UTF-16 string.
    const NTDLL_DLL: &[u16] = &[
        'N' as _, 'T' as _, 'D' as _, 'L' as _, 'L' as _, '.' as _, 'D' as _, 'L' as _,
        'L' as _, 0,
    ];

    // Message id handed to FormatMessageW. Kept separate from `errnum` so the
    // fallback messages below report exactly what the caller passed in.
    let mut msg_id = errnum as c::DWORD;
    let mut module: c::HMODULE = ptr::null_mut();
    let mut flags: c::DWORD = 0;

    // NTSTATUS errors may be encoded as HRESULT, which may be returned from
    // GetLastError. For more information about Windows error codes, see
    // `[MS-ERREF]`: https://msdn.microsoft.com/en-us/library/cc231198.aspx
    if (msg_id & c::FACILITY_NT_BIT) != 0 {
        // Format according to https://support.microsoft.com/en-us/help/259693
        // SAFETY: NTDLL_DLL is a valid, NUL-terminated, 'static wide string.
        module = unsafe { c::GetModuleHandleW(NTDLL_DLL.as_ptr()) };
        if !module.is_null() {
            msg_id ^= c::FACILITY_NT_BIT;
            flags = c::FORMAT_MESSAGE_FROM_HMODULE;
        }
    }

    let mut buf = [0 as c::WCHAR; 2048];
    // SAFETY: `buf` is valid for writes of `buf.len()` WCHARs; `module` is null
    // or a handle obtained from GetModuleHandleW above. With
    // FORMAT_MESSAGE_IGNORE_INSERTS set, the null argument list is not used.
    let res = unsafe {
        c::FormatMessageW(
            flags | c::FORMAT_MESSAGE_FROM_SYSTEM | c::FORMAT_MESSAGE_IGNORE_INSERTS,
            module,
            msg_id,
            LANG_ID,
            buf.as_mut_ptr(),
            buf.len() as c::DWORD,
            ptr::null(),
        )
    } as usize;

    if res == 0 {
        // FormatMessageW can fail, e.g. when the system rejects LANG_ID.
        let fm_err = errno();
        return format!(
            "OS Error {} (FormatMessageW() returned error {})",
            errnum, fm_err
        );
    }

    // Clamp defensively so a bogus length from the OS cannot cause a panic.
    let written = &buf[..res.min(buf.len())];
    match String::from_utf16(written) {
        Ok(mut msg) => {
            // Trim trailing CRLF inserted by FormatMessageW (in place, no realloc).
            let len = msg.trim_end().len();
            msg.truncate(len);
            msg
        }
        Err(..) => format!(
            "OS Error {} (FormatMessageW() returned \
             invalid UTF-16)",
            errnum
        ),
    }
}