diff options
Diffstat (limited to 'test/test_by_xed.py')
-rw-r--r-- | test/test_by_xed.py | 184 |
1 files changed, 140 insertions, 44 deletions
diff --git a/test/test_by_xed.py b/test/test_by_xed.py index f24d7f6..3e4b98f 100644 --- a/test/test_by_xed.py +++ b/test/test_by_xed.py @@ -7,6 +7,25 @@ class Reg: self.name = s def __str__(self): return self.name + def __eq__(self, rhs): + return self.name == rhs.name + def __lt__(self, rhs): + return self.name < rhs.name + +g_xmmTbl = ''' +xmm0 xmm1 xmm2 xmm3 xmm4 xmm5 xmm6 xmm7 +xmm8 xmm9 xmm10 xmm11 xmm12 xmm13 xmm14 xmm15 +xmm16 xmm17 xmm18 xmm19 xmm20 xmm21 xmm22 xmm23 +xmm24 xmm25 xmm26 xmm27 xmm28 xmm29 xmm30 xmm31 +ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 +ymm8 ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 +ymm16 ymm17 ymm18 ymm19 ymm20 ymm21 ymm22 ymm23 +ymm24 ymm25 ymm26 ymm27 ymm28 ymm29 ymm30 ymm31 +zmm0 zmm1 zmm2 zmm3 zmm4 zmm5 zmm6 zmm7 +zmm8 zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 +zmm16 zmm17 zmm18 zmm19 zmm20 zmm21 zmm22 zmm23 +zmm24 zmm25 zmm26 zmm27 zmm28 zmm29 zmm30 zmm31 +'''.split() g_regTbl = ''' eax ecx edx ebx esp ebp esi edi @@ -22,49 +41,53 @@ r16w r17w r18w r19w r20w r21w r22w r23w r24w r25w r26w r27w r28w r29w r30w r31w r8b r9b r10b r11b r12b r13b r14b r15b r16b r17b r18b r19b r20b r21b r22b r23b r24b r25b r26b r27b r28b r29b r30b r31b spl bpl sil dil -xmm0 xmm1 xmm2 xmm3 xmm4 xmm5 xmm6 xmm7 -xmm8 xmm9 xmm10 xmm11 xmm12 xmm13 xmm14 xmm15 -xmm16 xmm17 xmm18 xmm19 xmm20 xmm21 xmm22 xmm23 -xmm24 xmm25 xmm26 xmm27 xmm28 xmm29 xmm30 xmm31 -ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 -ymm8 ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 -ymm16 ymm17 ymm18 ymm19 ymm20 ymm21 ymm22 ymm23 -ymm24 ymm25 ymm26 ymm27 ymm28 ymm29 ymm30 ymm31 -zmm0 zmm1 zmm2 zmm3 zmm4 zmm5 zmm6 zmm7 -zmm8 zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 -zmm16 zmm17 zmm18 zmm19 zmm20 zmm21 zmm22 zmm23 -zmm24 zmm25 zmm26 zmm27 zmm28 zmm29 zmm30 zmm31 -'''.split() +tmm0 tmm1 tmm2 tmm3 tmm4 tmm5 tmm6 tmm7 +'''.split()+g_xmmTbl # define global constants for e in g_regTbl: globals()[e] = Reg(e) +g_maskTbl = [k1, k2, k3, k4, k5, k6, k7] + g_replaceCharTbl = '{}();|,' g_replaceChar = str.maketrans(g_replaceCharTbl, ' '*len(g_replaceCharTbl)) g_sizeTbl = ['byte', 'word', 'dword', 'qword', 'xword', 'yword', 'zword'] -g_attrTbl = ['T_sae', 'T_rn_sae', 'T_rd_sae', 'T_ru_sae', 'T_rz_sae'] #, 'T_z'] -g_attrXedTbl = ['sae', 'rne-sae', 'rd-sae', 'ru-sae', 'rz-sae'] +g_xedSizeTbl = ['xmmword', 'ymmword', 'zmmword'] +g_attrTbl = ['T_sae', 'T_rn_sae', 'T_rd_sae', 'T_ru_sae', 'T_rz_sae', 'T_z'] +g_attrXedTbl = ['sae', 'rne-sae', 'rd-sae', 'ru-sae', 'rz-sae', 'z'] class Attr: def __init__(self, s): self.name = s def __str__(self): return self.name + def __eq__(self, rhs): + return self.name == rhs.name + def __lt__(self, rhs): + return self.name < rhs.name for e in g_attrTbl: globals()[e] = Attr(e) +def newReg(s): + if type(s) == str: + return Reg(s) + return s + class Memory: - def __init__(self, size=0, base=None, index=None, scale=0, disp=0): + def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=False): self.size = size - self.base = base - self.index = index + self.base = newReg(base) + self.index = newReg(index) self.scale = scale self.disp = disp + self.broadcast = broadcast def __str__(self): s = 'ptr' if self.size == 0 else g_sizeTbl[int(math.log2(self.size))] + if self.broadcast: + s += '_b' s += ' [' needPlus = False if self.base: @@ -84,47 +107,72 @@ class Memory: s += ']' return s - def __eq__(self, rhs): - return str(self) == str(rhs) + # xbyak uses ptr if it is automatically detected, so xword == ptr is true + if self.broadcast != rhs.broadcast: return False +# if not self.broadcast and 0 < self.size <= 8 and 0 < rhs.size <= 8 and self.size != rhs.size: return False + if not self.broadcast and self.size > 0 and rhs.size > 0 and self.size != rhs.size: return False + r = self.base == rhs.base and self.index == rhs.index and self.scale == rhs.scale and self.disp == rhs.disp + return r -def parseMemory(s): - sizeTbl = { - 'byte': 1, 'word': 2, 'dword': 4, 'qword': 8, - 'xword': 16, 'yword': 32, 'zword': 64 - } +def parseBroadcast(s): + if '_b' in s: + return (s.replace('_b', ''), True) + r = re.search(r'({1to\d+})', s) + if not r: + return (s, False) + return (s.replace(r.group(1), ''), True) + +def parseMemory(s, broadcast=False): + org_s = s s = s.replace(' ', '').lower() - # Parse size size = 0 + base = index = None + scale = 0 + disp = 0 + + if not broadcast: + (s, broadcast) = parseBroadcast(s) + + # Parse size for i in range(len(g_sizeTbl)): w = g_sizeTbl[i] if s.startswith(w): size = 1<<i s = s[len(w):] + break + + if size == 0: + for i in range(len(g_xedSizeTbl)): + w = g_xedSizeTbl[i] + if s.startswith(w): + size = 1<<(i+4) + s = s[len(w):] + break # Remove 'ptr' if present if s.startswith('ptr'): s = s[3:] + if s.startswith('_b'): + broadcast = True + s = s[2:] + # Extract the content inside brackets r = re.match(r'\[(.*)\]', s) if not r: - raise ValueError(f'bad format {s=}') + raise ValueError(f'bad format {org_s=}') # Parse components elems = re.findall(r'([a-z0-9]+)(?:\*([0-9]+))?|([+-])', r.group(1)) - base = index = None - scale = 0 - disp = 0 - for i, e in enumerate(elems): if e[2]: # This is a '+' or '-' sign continue - if e[0].isalpha(): + if e[0] in g_regTbl: if base is None and (not e[1] or int(e[1]) == 1): base = e[0] elif index is None: @@ -137,25 +185,53 @@ def parseMemory(s): b = 16 if e[0].startswith('0x') else 10 disp += sign * int(e[0], b) - return Memory(size, base, index, scale, disp) + return Memory(size, base, index, scale, disp, broadcast) class Nmemonic: def __init__(self, name, args=[], attrs=[]): self.name = name self.args = args - self.attrs = attrs + self.attrs = attrs.sort() def __str__(self): s = f'{self.name}(' for i in range(len(self.args)): if i > 0: s += ', ' s += str(self.args[i]) - for e in self.attrs: - s += f'|{e}' + if i == 0 and self.attrs: + for e in self.attrs: + s += f'|{e}' s += ');' return s + def __eq__(self, rhs): + return self.name == rhs.name and self.args == rhs.args and self.attrs == rhs.attrs def parseNmemonic(s): + args = [] + attrs = [] + + (s, broadcast) = parseBroadcast(s) + + # replace xm0 with xmm0 + while True: + r = re.search(r'([xyz])m(\d\d?)', s) + if not r: + break + s = s.replace(r.group(0), r.group(1) + 'mm' + r.group(2)) + + # check 'zmm0{k7}' + r = re.search(r'({k[1-7]})', s) + if r: + idx = int(r.group(1)[2]) + attrs.append(g_maskTbl[idx-1]) + s = s.replace(r.group(1), '') + # check 'zmm0|k7' + r = re.search(r'(\|\s*k[1-7])', s) + if r: + idx = int(r.group(1)[-1]) + attrs.append(g_maskTbl[idx-1]) + s = s.replace(r.group(1), '') + s = s.translate(g_replaceChar) # reconstruct memory string @@ -168,13 +244,12 @@ def parseNmemonic(s): inMemory = False else: v.append(e) - if e in g_sizeTbl or e == 'ptr': + if e in g_sizeTbl or e in g_xedSizeTbl or e.startswith('ptr'): v[-1] += ' ' # to avoid 'byteptr' - inMemory = True + if ']' not in v[-1]: + inMemory = True name = v[0] - args = [] - attrs = [] for e in v[1:]: if e.startswith('0x'): args.append(int(e, 16)) @@ -185,9 +260,12 @@ def parseNmemonic(s): elif e in g_attrXedTbl: attrs.append(Attr(g_attrTbl[g_attrXedTbl.index(e)])) elif e in g_regTbl: - args.append(e) + args.append(Reg(e)) + # xed special format : xmm8+3 + elif e[:-2] in g_xmmTbl and e.endswith('+3'): + args.append(Reg(e[:-2])) else: - args.append(parseMemory(e)) + args.append(parseMemory(e, broadcast)) return Nmemonic(name, args, attrs) def loadFile(name): @@ -215,13 +293,17 @@ def run(cppText, xedText): m1 = parseNmemonic(line1) m2 = parseNmemonic(line2) - assertEqualStr(m1, m2, f'{i}') + assertEqual(m1, m2, f'{i+1}') print('run ok') def assertEqualStr(a, b, msg=None): if str(a) != str(b): raise Exception(f'assert fail {msg}:', str(a), str(b)) +def assertEqual(a, b, msg=None): + if a != b: + raise Exception(f'assert fail {msg}:', str(a), str(b)) + def MemoryTest(): tbl = [ (Memory(0, rax), 'ptr [rax]'), @@ -231,18 +313,23 @@ def MemoryTest(): (Memory(8, None, rcx, 4), 'qword [rcx*4]'), (Memory(8, rax, None, 0, 5), 'qword [rax+0x5]'), (Memory(8, None, None, 0, 255), 'qword [0xff]'), + (Memory(0, r8, r9, 1, 32), 'ptr [r8+r9+0x20]'), ] for (m, expected) in tbl: assertEqualStr(m, expected) + assertEqual(Memory(16, rax), Memory(0, rax)) + def parseMemoryTest(): print('parseMemoryTest') tbl = [ ('[]', Memory()), ('[rax]', Memory(0, rax)), ('ptr[rax]', Memory(0, rax)), + ('ptr_b[rax]', Memory(0, rax, broadcast=True)), ('dword[rbx]', Memory(4, rbx)), ('xword ptr[rcx]', Memory(16, rcx)), + ('xmmword ptr[rcx]', Memory(16, rcx)), ('xword ptr[rdx*8]', Memory(16, None, rdx, 8)), ('[12345]', Memory(0, None, None, 0, 12345)), ('[0x12345]', Memory(0, None, None, 0, 0x12345)), @@ -262,10 +349,19 @@ def parseNmemonicTest(): ('mov(rax, ptr [rcx + rdx * 8 ] );', Nmemonic('mov', [rax, Memory(0, rcx, rdx, 8)])), ('vcmppd(k1, ymm2, ymm3 |T_sae, 3);', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])), ('vcmppd k1{sae}, ymm2, ymm3, 0x3', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])), + ('v4fmaddps zmm1, zmm8+3, xmmword ptr [rdx+0x40]', Nmemonic('v4fmaddps', [zmm1, zmm8, Memory(16, rdx, None, 0, 0x40)])), + ('vp4dpwssd zmm23{k7}{z}, zmm1+3, xmmword ptr [rax+0x40]', Nmemonic('vp4dpwssd', [zmm23, zmm1, Memory(16, rax, None, 0, 0x40)], [k7, T_z])), + ('v4fnmaddps(zmm5 | k5, zmm2, ptr [rcx + 0x80]);', Nmemonic('v4fnmaddps', [zmm5, zmm2, Memory(0, rcx, None, 0, 0x80)], [k5])), + ('vpcompressw(zmm30 | k2 |T_z, zmm1);', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])), + ('vpcompressw zmm30{k2}{z}, zmm1', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])), + ('vpshldw(xmm9|k3|T_z, xmm2, ptr [rax + 0x40], 5);', Nmemonic('vpshldw', [xmm9, xmm2, Memory(0, rax, None, 0, 0x40), 5], [k3, T_z])), + ('vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])), + ('vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])), + ('vcmpph(k1, xm15, ptr[rax+64], 1);', Nmemonic('vcmpph', [k1, xm15, Memory(0, rax, None, 0, 64), 1])), ] for (s, expected) in tbl: e = parseNmemonic(s) - assertEqualStr(e, expected) + assertEqual(e, expected) def test(): print('test start') |