aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/test_by_xed.py
blob: 1e84c6aecbf1d78df55092bf2af6fe8e79bc7258 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import re
import math
import sys

class Reg:
  """A register operand identified only by its name (e.g. 'rax', 'zmm3', 'k1').

  Instances compare and order by ``name``.  Comparison is accepted against any
  object exposing a string ``name`` attribute, because mask registers and Attr
  objects end up in the same attribute list, which is sorted.  Anything else
  (None, ints, memory operands, ...) simply compares unequal instead of
  raising AttributeError.
  """
  def __init__(self, s):
    self.name = s
  def __str__(self):
    return self.name
  def __repr__(self):
    return f'Reg({self.name!r})'
  def __eq__(self, rhs):
    other = getattr(rhs, 'name', None)
    if not isinstance(other, str):
      # let Python fall back to its default (identity) comparison
      return NotImplemented
    return self.name == other
  def __lt__(self, rhs):
    other = getattr(rhs, 'name', None)
    if not isinstance(other, str):
      return NotImplemented
    return self.name < other
  def __hash__(self):
    # defining __eq__ disables the inherited hash; keep it consistent with __eq__
    return hash(self.name)

# SIMD register names in table order: xmm0..xmm31, then ymm0..ymm31, then zmm0..zmm31
g_xmmTbl = [f'{prefix}mm{idx}' for prefix in 'xyz' for idx in range(32)]

# Names of every register the parser recognizes, followed by the SIMD names.
# The order matches the original hand-written table.
_gprStems = 'ax cx dx bx sp bp si di'.split()
g_regTbl = (
  ['e' + stem for stem in _gprStems]          # eax ecx ... edi
  + _gprStems                                 # ax cx ... di
  + 'al cl dl bl ah ch dh bh'.split()
  + [f'k{idx}' for idx in range(1, 8)]        # opmask registers k1..k7
  + ['r' + stem for stem in _gprStems]        # rax rcx ... rdi
  + [f'r{idx}' for idx in range(8, 32)]       # r8..r31
  # r8d..r31d, then r8w..r31w, then r8b..r31b
  + [f'r{idx}{suffix}' for suffix in 'dwb' for idx in range(8, 32)]
  + 'spl bpl sil dil'.split()
  + [f'tmm{idx}' for idx in range(8)]         # tmm0..tmm7
  + g_xmmTbl
)

# Publish one module-level Reg object per register name (rax, zmm0, k1, ...)
globals().update({regName: Reg(regName) for regName in g_regTbl})

# opmask registers k1..k7; entry i holds k(i+1)
g_maskTbl = [globals()[f'k{idx}'] for idx in range(1, 8)]

# punctuation that is blanked out before an instruction line is split into tokens
g_replaceCharTbl = '{}();|,'
g_replaceChar = {ord(ch): ord(' ') for ch in g_replaceCharTbl}
# xbyak size keywords; the keyword at position i means 1 << i bytes
g_sizeTbl = 'byte word dword qword xword yword zword'.split()
# xed size keywords for 16, 32 and 64 bytes
g_xedSizeTbl = 'xmmword ymmword zmmword'.split()
# xbyak attribute names and the xed spelling at the same position
g_attrTbl = 'T_sae T_rn_sae T_rd_sae T_ru_sae T_rz_sae T_z'.split()
g_attrXedTbl = 'sae rne-sae rd-sae ru-sae rz-sae z'.split()

class Attr:
  """An instruction attribute identified only by its name (e.g. 'T_sae', 'T_z').

  Instances compare and order by ``name``.  Comparison is accepted against any
  object exposing a string ``name`` attribute, because attribute lists mix
  Attr objects with mask registers and are sorted.  Anything else (None,
  ints, ...) compares unequal instead of raising AttributeError.
  """
  def __init__(self, s):
    self.name = s
  def __str__(self):
    return self.name
  def __repr__(self):
    return f'Attr({self.name!r})'
  def __eq__(self, rhs):
    other = getattr(rhs, 'name', None)
    if not isinstance(other, str):
      # let Python fall back to its default (identity) comparison
      return NotImplemented
    return self.name == other
  def __lt__(self, rhs):
    other = getattr(rhs, 'name', None)
    if not isinstance(other, str):
      return NotImplemented
    return self.name < other
  def __hash__(self):
    # defining __eq__ disables the inherited hash; keep it consistent with __eq__
    return hash(self.name)

# Publish one module-level Attr object per attribute name (T_sae, T_z, ...)
globals().update({attrName: Attr(attrName) for attrName in g_attrTbl})

def newReg(s):
  """Wrap a plain register-name string in a Reg; return any other value unchanged."""
  return Reg(s) if type(s) == str else s

class Memory:
  """A memory operand: access size, base/index register, scale, displacement.

  size: access size in bytes, 0 meaning unspecified (printed as 'ptr').
  base, index: Reg objects, register-name strings (converted by newReg) or None.
  scale: multiplier applied to the index register.
  disp: integer displacement.
  broadcast: True for a broadcast operand (printed with the '_b' suffix).
  """
  def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=False):
    self.size = size
    self.base = newReg(base)
    self.index = newReg(index)
    self.scale = scale
    self.disp = disp
    self.broadcast = broadcast

  def __str__(self):
    """Format the operand in xbyak notation, e.g. 'qword [rax+rcx*4+0x10]'."""
    s = 'ptr' if self.size == 0 else g_sizeTbl[int(math.log2(self.size))]
    if self.broadcast:
      s += '_b'
    s += ' ['
    needPlus = False
    if self.base:
      s += str(self.base)
      needPlus = True
    if self.index:
      if needPlus:
        s += '+'
      s += str(self.index)
      if self.scale > 1:
        s += f'*{self.scale}'
      needPlus = True
    if self.disp:
      if self.disp < 0:
        # hex() already yields a leading '-', so emit the sign ourselves
        # to get 'rax-0x8' rather than 'rax+-0x8'
        s += '-' + hex(-self.disp)
      else:
        if needPlus:
          s += '+'
        s += hex(self.disp)
    s += ']'
    return s

  @staticmethod
  def _sameReg(a, b):
    """Compare two optional registers without handing None to a register's __eq__."""
    if a is None or b is None:
      return a is b
    return a == b

  def __eq__(self, rhs):
    """Return True when both operands address the same location.

    xbyak uses ptr if the size is automatically detected, so an unspecified
    size (0) matches any size, e.g. xword == ptr is true.  Sizes are not
    compared at all for broadcast operands.  A non-Memory rhs is never equal.
    """
    if not isinstance(rhs, Memory):
      return NotImplemented
    if self.broadcast != rhs.broadcast:
      return False
    sizesKnown = self.size > 0 and rhs.size > 0
    if not self.broadcast and sizesKnown and self.size != rhs.size:
      return False
    return (Memory._sameReg(self.base, rhs.base)
            and Memory._sameReg(self.index, rhs.index)
            and self.scale == rhs.scale
            and self.disp == rhs.disp)

def parseBroadcast(s):
  """Strip a broadcast marker from s.

  The '_b' suffix is checked first; otherwise a '{1toN}' decoration is
  searched for.  Returns (text without the marker, True) when one is
  present, else (s, False).
  """
  if '_b' not in s:
    found = re.search(r'({1to\d+})', s)
    marker = found.group(1) if found else None
  else:
    marker = '_b'
  if marker is None:
    return (s, False)
  return (s.replace(marker, ''), True)

def parseMemory(s, broadcast=False):
  """Parse an address expression such as 'dword ptr [rax+rcx*4+0x10]' into a Memory.

  Spaces are removed and the text is lower-cased first.  A size keyword from
  g_sizeTbl or g_xedSizeTbl may lead the text, followed by an optional 'ptr'
  and an optional '_b' broadcast suffix.

  s: operand text.
  broadcast: True when the caller already found a broadcast marker; when
    False the marker is looked for (and stripped) here via parseBroadcast.
  Returns a Memory.  Raises ValueError when no '[...]' part is found, when a
  register fits neither the base nor the index slot, or when a term is
  neither a known register nor a number.
  """
  org_s = s
  s = s.replace(' ', '').lower()

  if not broadcast:
    (s, broadcast) = parseBroadcast(s)

  # Size keyword: position i in g_sizeTbl means 1 << i bytes; the xed
  # keywords are tried only if none of those matched and start at 1 << 4.
  size = 0
  for shift, keyword in enumerate(g_sizeTbl):
    if s.startswith(keyword):
      size = 1 << shift
      s = s[len(keyword):]
      break
  else:
    for shift, keyword in enumerate(g_xedSizeTbl, start=4):
      if s.startswith(keyword):
        size = 1 << shift
        s = s[len(keyword):]
        break

  # Optional 'ptr' keyword
  if s.startswith('ptr'):
    s = s[3:]

  # Optional broadcast suffix that is still attached at this point
  if s.startswith('_b'):
    broadcast = True
    s = s[2:]

  # Everything of interest sits between the brackets
  bracket = re.match(r'\[(.*)\]', s)
  if bracket is None:
    raise ValueError(f'bad format {org_s=}')

  base = index = None
  scale = 0
  disp = 0
  prevSign = ''  # sign token of the element just before the current one, if any
  for (term, factor, signTok) in re.findall(r'([a-z0-9]+)(?:\*([0-9]+))?|([+-])', bracket.group(1)):
    negate = prevSign == '-'
    prevSign = signTok
    if signTok:
      # a bare '+' or '-'; it only matters for the number that follows
      continue

    if term not in g_regTbl:
      # numeric displacement, hexadecimal when prefixed with 0x
      radix = 16 if term.startswith('0x') else 10
      magnitude = int(term, radix)
      disp += -magnitude if negate else magnitude
      continue

    # the first register without a scale (or with *1) becomes the base
    if base is None and (not factor or int(factor) == 1):
      base = term
    elif index is None:
      index = term
      scale = int(factor) if factor else 1
    else:
      raise ValueError(f'bad format2 {s=}')

  return Memory(size, base, index, scale, disp, broadcast)

class Nmemonic:
  """One instruction: mnemonic name, operand list and sorted attribute list.

  name: mnemonic text.
  args: list of operands; stored as given (a new empty list when omitted).
  attrs: attributes / mask registers; stored as a new sorted list so that the
    order in which they were written does not affect equality.
  """
  def __init__(self, name, args=None, attrs=None):
    self.name = name
    # None defaults instead of [] so instances never share one list object
    self.args = [] if args is None else args
    # sorted() returns the ordered list; list.sort() returns None, which
    # would leave the attributes out of both __eq__ and __str__
    self.attrs = sorted(attrs) if attrs else []
  def __str__(self):
    """Format as 'name(arg0|attr..., arg1, ...);' with attrs attached to the first operand."""
    operands = [str(a) for a in self.args]
    if operands and self.attrs:
      operands[0] += ''.join(f'|{e}' for e in self.attrs)
    return f'{self.name}(' + ', '.join(operands) + ');'
  def __eq__(self, rhs):
    if not isinstance(rhs, Nmemonic):
      return NotImplemented
    return self.name == rhs.name and self.args == rhs.args and self.attrs == rhs.attrs

def parseNmemonic(s):
  """Parse one instruction line into a Nmemonic.

  Two spellings are handled by the same code path: a call-like form such as
  'vaddpd(ymm1, ymm2, ymm3 |T_rn_sae);' and an assembly-like form such as
  'vaddpd ymm1{rne-sae}, ymm2, ymm3'.  The steps below rewrite s in place and
  depend on running in this order.

  Returns Nmemonic(name, args, attrs) where args holds ints (immediates), Reg
  objects and Memory objects, and attrs holds mask registers and Attr objects.
  Operands that match none of the known forms are handed to parseMemory,
  which raises ValueError on text it cannot parse.
  """
  args = []
  attrs = []

  # remove Xbyak::{Evex,Vex}Encoding
  # (drops a ',<...>Encoding' argument; the match runs from a comma to the word 'Encoding')
  r = re.search(r'(,[^,]*Encoding)', s)
  if r:
    s = s.replace(r.group(1), '')

  # strip a '_b' / '{1toN}' marker; the flag is passed to every memory operand parsed below
  (s, broadcast) = parseBroadcast(s)

  # replace xm0 with xmm0
  # (repeats until no 'xmN' / 'ymN' / 'zmN' spelling is left)
  while True:
    r = re.search(r'([xyz])m(\d\d?)', s)
    if not r:
      break
    s = s.replace(r.group(0), r.group(1) + 'mm' + r.group(2))

  # check 'zmm0{k7}'
  r = re.search(r'({k[1-7]})', s)
  if r:
    # r.group(1) is '{kN}', so index 2 is the digit
    idx = int(r.group(1)[2])
    attrs.append(g_maskTbl[idx-1])
    s = s.replace(r.group(1), '')
  # check 'zmm0|k7'
  r = re.search(r'(\|\s*k[1-7])', s)
  if r:
    # the last character of the match is the digit
    idx = int(r.group(1)[-1])
    attrs.append(g_maskTbl[idx-1])
    s = s.replace(r.group(1), '')

  # turn the punctuation listed in g_replaceCharTbl into spaces so split() yields tokens
  s = s.translate(g_replaceChar)

  # reconstruct memory string
  # A size keyword or a token starting with 'ptr' opens a memory operand; the
  # following tokens are glued onto it until one containing ']' is seen.
  v = []
  inMemory = False
  for e in s.split():
    if inMemory:
      v[-1] += e
      if ']' in e:
        inMemory = False
    else:
      v.append(e)
      if e in g_sizeTbl or e in g_xedSizeTbl or e.startswith('ptr'):
        v[-1] += ' ' # to avoid 'byteptr'
        if ']' not in v[-1]:
          inMemory = True

  # first token is the mnemonic, the rest are classified one by one
  name = v[0]
  for e in v[1:]:
    if e.startswith('0x'):
      # hexadecimal immediate
      args.append(int(e, 16))
    elif e[0] in '0123456789':
      # decimal immediate
      args.append(int(e))
    elif e in g_attrTbl:
      attrs.append(Attr(e))
    elif e in g_attrXedTbl:
      # map the xed spelling to the xbyak attribute at the same table position
      attrs.append(Attr(g_attrTbl[g_attrXedTbl.index(e)]))
    elif e in g_regTbl:
      args.append(Reg(e))
    # xed special format : xmm8+3
    # (only the register part is kept)
    elif e[:-2] in g_xmmTbl and e.endswith('+3'):
      args.append(Reg(e[:-2]))
    else:
      # anything else is treated as a memory operand
      args.append(parseMemory(e, broadcast))
  return Nmemonic(name, args, attrs)

def loadFile(name):
  """Read the text file `name` and return its lines as a list.

  Empty lines and lines that start with '#' or '//' are left out.
  """
  with open(name) as f:
    text = f.read()
  return [
    row for row in text.split('\n')
    if row and row[0] != '#' and not row.startswith('//')
  ]

# Drop the first five whitespace-separated fields of a xed output line and
# return the remaining fields joined by single spaces.
# e.g. XDIS 0: AVX512    AVX512EVEX 62F1E91858CB             vaddpd ymm1{rne-sae}, ymm2, ymm3
#   -> vaddpd ymm1{rne-sae}, ymm2, ymm3
def removeExtraInfo(s):
  fields = s.split()
  del fields[:5]
  return ' '.join(fields)

def run(cppText, xedText):
  """Compare two instruction listings line by line.

  cppText, xedText: paths of the two text files, read with loadFile.  Each
  xed line has its leading fields removed by removeExtraInfo before parsing.
  Raises Exception when the files hold a different number of lines or when
  a pair of lines parses to different instructions (the message carries the
  1-based line number).  Prints 'run ok' and the line count on success.
  """
  cpp = loadFile(cppText)
  xed = loadFile(xedText)
  n = len(cpp)
  if n != len(xed):
    raise Exception(f'different line {n} {len(xed)}')

  for lineNo, (cppLine, xedRaw) in enumerate(zip(cpp, xed), start=1):
    xedLine = removeExtraInfo(xedRaw)
    parsedCpp = parseNmemonic(cppLine)
    parsedXed = parseNmemonic(xedLine)
    assertEqual(parsedCpp, parsedXed, f'{lineNo}')
  print('run ok', n)

def assertEqualStr(a, b, msg=None):
  """Raise Exception unless str(a) and str(b) are the same text."""
  left, right = str(a), str(b)
  if left == right:
    return
  raise Exception(f'assert fail {msg}:', left, right)

def assertEqual(a, b, msg=None):
  """Raise Exception, carrying str(a) and str(b), when a != b."""
  differs = a != b
  if differs:
    details = (f'assert fail {msg}:', str(a), str(b))
    raise Exception(*details)

def MemoryTest():
  """Check Memory.__str__ against expected text and the size-tolerant equality."""
  cases = (
    (Memory(0, rax), 'ptr [rax]'),
    (Memory(4, rax), 'dword [rax]'),
    (Memory(8, rax, rcx), 'qword [rax+rcx]'),
    (Memory(8, rax, rcx, 4), 'qword [rax+rcx*4]'),
    (Memory(8, None, rcx, 4), 'qword [rcx*4]'),
    (Memory(8, rax, None, 0, 5), 'qword [rax+0x5]'),
    (Memory(8, None, None, 0, 255), 'qword [0xff]'),
    (Memory(0, r8, r9, 1, 32), 'ptr [r8+r9+0x20]'),
  )
  for operand, text in cases:
    assertEqualStr(operand, text)

  # an unspecified size (ptr) must match an explicit one
  sized = Memory(16, rax)
  unsized = Memory(0, rax)
  assertEqual(sized, unsized)

def parseMemoryTest():
  """Check that parseMemory output prints the same as the expected Memory."""
  print('parseMemoryTest')
  cases = (
    ('[]', Memory()),
    ('[rax]', Memory(0, rax)),
    ('ptr[rax]', Memory(0, rax)),
    ('ptr_b[rax]', Memory(0, rax, broadcast=True)),
    ('dword[rbx]', Memory(4, rbx)),
    ('xword ptr[rcx]', Memory(16, rcx)),
    ('xmmword ptr[rcx]', Memory(16, rcx)),
    ('xword ptr[rdx*8]', Memory(16, None, rdx, 8)),
    ('[12345]', Memory(0, None, None, 0, 12345)),
    ('[0x12345]', Memory(0, None, None, 0, 0x12345)),
    ('yword [rax+rdx*4]', Memory(32, rax, rdx, 4)),
    ('zword [rax+rdx*4+123]', Memory(64, rax, rdx, 4, 123)),
  )
  for text, wanted in cases:
    # compared through str(), not Memory.__eq__
    assertEqualStr(parseMemory(text), wanted)

def parseNmemonicTest():
  """Check parseNmemonic on call-like and assembly-like instruction lines."""
  print('parseNmemonicTest')
  cases = (
    ('vaddpd(ymm1, ymm2, ymm3 |T_rn_sae);', Nmemonic('vaddpd', [ymm1, ymm2, ymm3], [T_rn_sae])),
    ('vaddpd ymm1{rne-sae}, ymm2, ymm3', Nmemonic('vaddpd', [ymm1, ymm2, ymm3], [T_rn_sae])),
    ('mov(rax, dword ptr [rcx + rdx * 8 ] );', Nmemonic('mov', [rax, Memory(4, rcx, rdx, 8)])),
    ('mov(rax, ptr [rcx + rdx * 8 ] );', Nmemonic('mov', [rax, Memory(0, rcx, rdx, 8)])),
    ('vcmppd(k1, ymm2, ymm3 |T_sae, 3);', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])),
    ('vcmppd k1{sae}, ymm2, ymm3, 0x3', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])),
    ('v4fmaddps zmm1, zmm8+3, xmmword ptr [rdx+0x40]', Nmemonic('v4fmaddps', [zmm1, zmm8, Memory(16, rdx, None, 0, 0x40)])),
    ('vp4dpwssd zmm23{k7}{z}, zmm1+3, xmmword ptr [rax+0x40]', Nmemonic('vp4dpwssd', [zmm23, zmm1, Memory(16, rax, None, 0, 0x40)], [k7, T_z])),
    ('v4fnmaddps(zmm5 | k5, zmm2, ptr [rcx + 0x80]);', Nmemonic('v4fnmaddps', [zmm5, zmm2, Memory(0, rcx, None, 0, 0x80)], [k5])),
    ('vpcompressw(zmm30 | k2 |T_z, zmm1);', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
    ('vpcompressw zmm30{k2}{z}, zmm1', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
    ('vpshldw(xmm9|k3|T_z, xmm2, ptr [rax + 0x40], 5);', Nmemonic('vpshldw', [xmm9, xmm2, Memory(0, rax, None, 0, 0x40), 5], [k3, T_z])),
    ('vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
    ('vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
    ('vcmpph(k1, xmm15, ptr[rax+64], 1);', Nmemonic('vcmpph', [k1, xmm15, Memory(0, rax, None, 0, 64), 1])),
  )
  for text, wanted in cases:
    parsed = parseNmemonic(text)
    assertEqual(parsed, wanted)

def test():
  """Run the built-in self tests in order, bracketed by start/end messages."""
  print('test start')
  for check in (MemoryTest, parseMemoryTest, parseNmemonicTest):
    check()
  print('test end')

def main():
  """Command-line entry point.

  '<prog> test' runs the built-in self tests; '<prog> <cpp-text> <xed-text>'
  compares the two files with run(); any other argument list prints usage.
  """
  argv = sys.argv
  if len(argv) == 2 and argv[1] == 'test':
    test()
  elif len(argv) == 3:
    run(argv[1], argv[2])
  else:
    # show the invoked program name; __name__ is always '__main__' when this
    # file runs as a script, which is useless in a usage line
    prog = argv[0] if argv else __name__
    print(f'{prog} <cpp-text> <xed-text> # compare cpp-text and xed-text generated by xed')
    print(f'{prog} test # for test')

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
  main()