| import dis |
| import RightBack |
| import marshal |
| import types |
| from graphviz import Digraph |
| |
| |
def get_funcs(module):
    """Collect the plain Python functions exposed as attributes of *module*.

    Attribute names are visited in ``dir(module)`` order; only objects that are
    instances of ``types.FunctionType`` are kept (classes, builtins and other
    callables are skipped).

    Returns:
        list: the matching function objects.
    """
    attributes = (getattr(module, attr_name) for attr_name in dir(module))
    return [candidate for candidate in attributes
            if isinstance(candidate, types.FunctionType)]
| |
def get_extend_size(code, offset):
    """Measure the run of EXTENDED_ARG code units starting at *offset*.

    Args:
        code: raw bytecode (``bytes``), two bytes per code unit.
        offset: byte offset where scanning starts.

    Returns:
        int: number of bytes occupied by consecutive EXTENDED_ARG units,
        0 when the unit at *offset* is not EXTENDED_ARG.
    """
    end = offset
    # Each EXTENDED_ARG is one 2-byte code unit (opcode + arg byte).
    while code[end] == dis.EXTENDED_ARG:
        end += 2
    return end - offset
| |
| |
class InstWrapper:
    """Control-flow-graph node: one instruction plus its EXTENDED_ARG prefixes.

    Nodes are linked through ``preds``/``succs`` sets and can be re-encoded at
    a (possibly new) ``offset`` with :meth:`print_asm`.
    """

    def __init__(self, prefix_insn: list[dis.Instruction], main_insn: dis.Instruction):
        self.prefix_insn = prefix_insn
        self.main_insn = main_insn
        # The node is addressed by its first byte: the first prefix if any.
        self.offset = prefix_insn[0].offset if prefix_insn else main_insn.offset
        self.preds = set()
        self.succs = set()

    def edge(self, v1):
        """Add the control-flow edge ``self -> v1``."""
        v1.preds.add(self)
        self.succs.add(v1)

    def unlink(self):
        """Remove this node from its neighbours' edge sets.

        ``self.preds`` and ``self.succs`` are deliberately left untouched so
        the caller can still read them after unlinking.
        """
        for n in self.preds:
            n.succs.remove(self)
        for n in self.succs:
            n.preds.remove(self)

    def size(self):
        """Return the encoded size in bytes (2 per code unit, prefixes included)."""
        return 2 * (len(self.prefix_insn) + 1)

    def print_asm(self):
        """Encode this node at ``self.offset`` as a list of byte values.

        Non-jump instructions are re-emitted with their original argument
        bytes.  For jumps, the target is the successor that is not laid out
        right after this node (or the next address if every successor is);
        relative jumps are encoded as a delta from the end of the instruction.
        The original number of EXTENDED_ARG prefixes is reused.

        Raises:
            ValueError: if a jump has no successor, or its recomputed argument
                is negative or does not fit in the available bytes.
        """
        assert len(self.succs) <= 2
        opcode = self.main_insn.opcode
        if opcode not in dis.hasjabs and opcode not in dis.hasjrel:
            result = []
            for p in self.prefix_insn:
                # dis reports the *accumulated* argument on each instruction of
                # an EXTENDED_ARG chain; only the low byte belongs to this unit.
                result += [p.opcode, (p.arg or 0) & 0xFF]
            result += [opcode, (self.main_insn.arg or 0) & 0xFF]
            return result

        if not self.succs:
            raise ValueError('jump at offset %d has no successor' % self.offset)
        fall_through = self.offset + self.size()
        # If no successor sits elsewhere, the jump simply targets the next
        # instruction (previously this left the target as None and crashed).
        ref_target = next(
            (n.offset for n in self.succs if n.offset != fall_through),
            fall_through,
        )
        argval = ref_target - fall_through if opcode in dis.hasjrel else ref_target

        byte_size = len(self.prefix_insn) + 1
        # Refuse to silently truncate: a masked argument is a corrupt jump.
        if not 0 <= argval < (1 << (8 * byte_size)):
            raise ValueError(
                'jump at offset %d needs argument %d, which does not fit in %d byte(s)'
                % (self.offset, argval, byte_size))
        raw = argval.to_bytes(byte_size, byteorder='big')
        result = []
        for b in raw[:-1]:
            result += [dis.opmap['EXTENDED_ARG'], b]
        result += [opcode, raw[-1]]
        return result

    def __str__(self) -> str:
        argval = self.main_insn.argval
        # Only a missing argument is blank; 0 / False are still shown.
        shown = '' if argval is None else str(argval)
        return '%04d %s\t(%s)' % (self.offset, self.main_insn.opname, shown)
| |
def fix_invalid(func):
    """Return *func*'s bytecode with every byte unreachable from offset 0 set to NOP.

    Reachability is traced per instruction, EXTENDED_ARG prefixes included:
    ordinary instructions fall through; other jump opcodes (``dis.hasjabs`` /
    ``dis.hasjrel``) continue at both their target and the next instruction;
    JUMP_FORWARD and JUMP_ABSOLUTE continue only at their target; and
    RETURN_VALUE, RAISE_VARARGS and RERAISE end the path.

    Args:
        func: function whose ``__code__.co_code`` is analysed.

    Returns:
        bytes: rewritten bytecode of the same length (offsets are unchanged).
    """
    raw_code = func.__code__.co_code
    insn_map = {insn.offset: insn for insn in dis._get_instructions_bytes(raw_code)}

    terminators = {'RETURN_VALUE', 'RAISE_VARARGS', 'RERAISE'}
    unconditional = {'JUMP_FORWARD', 'JUMP_ABSOLUTE'}
    jump_ops = set(dis.hasjabs) | set(dis.hasjrel)

    # start offset -> encoded size in bytes (prefixes + main instruction)
    trace_result = {}
    # Explicit work list instead of recursion: the old recursive tracer used
    # one stack frame per fall-through and overflowed on long functions.
    pending = [0]
    while pending:
        offset = pending.pop()
        assert offset in insn_map
        if offset in trace_result:
            continue

        extend_size = get_extend_size(raw_code, offset)
        cur_size = extend_size + 2
        cur = insn_map[offset + extend_size]
        trace_result[offset] = cur_size

        if cur.opname in terminators:
            continue
        if cur.opcode in jump_ops:
            pending.append(cur.argval)
            if cur.opname in unconditional:
                continue
        pending.append(offset + cur_size)

    reached = set()
    for start, size in trace_result.items():
        reached.update(range(start, start + size))
    nop = dis.opmap['NOP']
    return bytes(b if i in reached else nop for i, b in enumerate(raw_code))
| |
def _build_cfg(code_bytes):
    """Wrap each instruction of *code_bytes* in an InstWrapper and link the CFG.

    EXTENDED_ARG instructions are folded into the next instruction's prefix.
    Edges are added for fall-through (except after instructions that never
    fall through) and for jump targets before or after the jump.

    Returns:
        dict: ``{start offset: node}`` in bytecode order.
    """
    no_fallthrough = {'RETURN_VALUE', 'RAISE_VARARGS', 'RERAISE',
                      'JUMP_FORWARD', 'JUMP_ABSOLUTE'}
    jump_ops = set(dis.hasjabs) | set(dis.hasjrel)
    nodes = {}
    waiting = {}  # forward-jump target offset -> jump nodes awaiting that node
    prefix = []
    last = None
    for insn in dis._get_instructions_bytes(code_bytes):
        if insn.opname == 'EXTENDED_ARG':
            prefix.append(insn)
            continue
        node = InstWrapper(prefix, insn)
        prefix = []
        # Register before resolving jumps so a jump to its own offset becomes
        # a self-loop instead of being silently dropped.
        nodes[node.offset] = node
        for src in waiting.pop(node.offset, ()):
            src.edge(node)
        if last is not None and last.main_insn.opname not in no_fallthrough:
            last.edge(node)
        last = node
        if insn.opcode in jump_ops:
            target = insn.argval
            if target in nodes:
                node.edge(nodes[target])
            else:
                waiting.setdefault(target, set()).add(node)
    return nodes


def _prune_cfg(nodes):
    """Repeatedly drop dead nodes and splice out pass-through JUMP_FORWARDs.

    A node other than offset 0 with no predecessors is removed.  A
    JUMP_FORWARD with exactly one predecessor and one successor is replaced by
    a direct edge.  Returns the surviving nodes in their original order.
    """
    nodes = list(nodes)
    while True:
        removed = set()
        for n in nodes:
            if not n.preds and n.offset != 0:
                n.unlink()
                removed.add(n)
            if len(n.preds) == 1 and len(n.succs) == 1 and n.main_insn.opname == 'JUMP_FORWARD':
                n.unlink()
                p = n.preds.pop()
                s = n.succs.pop()
                p.edge(s)
                removed.add(n)
        if not removed:
            return nodes
        nodes = [n for n in nodes if n not in removed]


def _layout_cfg(entry):
    """Assign new offsets in depth-first preorder from *entry*.

    Each node's successors are explored in ascending order of their offset at
    the moment the node is placed.  The traversal is iterative (an iterator
    stack), so long functions cannot overflow the recursion limit.

    Returns:
        list: reachable nodes in layout order; ``entry`` is first at offset 0.
    """
    def by_offset(n):
        return n.offset

    entry.offset = 0
    order = [entry]
    # The entry is marked visited up front: otherwise a loop back to offset 0
    # would re-place the entry node at a later offset.
    visited = {entry}
    next_offset = entry.size()
    stack = [iter(sorted(entry.succs, key=by_offset))]
    while stack:
        node = next(stack[-1], None)
        if node is None:
            stack.pop()
            continue
        if node in visited:
            continue
        visited.add(node)
        node.offset = next_offset
        next_offset += node.size()
        order.append(node)
        stack.append(iter(sorted(node.succs, key=by_offset)))
    return order


def recompile(func, new_code):
    """Rebuild *func*'s code object from the cleaned bytecode *new_code*.

    The bytecode is turned into a CFG, dead nodes and pass-through
    JUMP_FORWARDs are removed, the remaining reachable nodes are laid out
    depth-first from offset 0, and jump arguments are re-encoded for the new
    layout.  Bytes are emitted in that same layout order so jump targets
    match the physical positions.

    NOTE(review): the preorder layout does not guarantee that a fall-through
    successor lands immediately after its predecessor, and no compensating
    jumps are inserted.

    Args:
        func: the function whose ``__code__`` supplies every other field.
        new_code: bytecode to rebuild, e.g. the output of ``fix_invalid``.

    Returns:
        types.CodeType: ``func.__code__`` with ``co_code`` replaced and
        ``co_name`` set to ``func.__name__``.

    Raises:
        ValueError: if no instruction at offset 0 survives pruning.
    """
    nodes = _build_cfg(new_code)
    live = _prune_cfg(nodes.values())
    entry_point = next((n for n in live if n.offset == 0), None)
    if entry_point is None:
        raise ValueError('%s: no instruction left at offset 0' % func.__name__)

    new_bytecode = []
    for node in _layout_cfg(entry_point):
        new_bytecode += node.print_asm()
    print(new_bytecode)

    # Only co_code (and co_name) change; line-number tables are copied as-is
    # and still describe the original offsets.
    return func.__code__.replace(co_code=bytes(new_bytecode), co_name=func.__name__)
| |
| |
# 16-byte .pyc header written before each marshalled code object: magic
# b'a\r\r\n' (3425, CPython 3.9), a zero flags word (timestamp-based pyc), then
# hard-coded mtime / source-size fields — presumably copied from RightBack.pyc,
# TODO confirm.
PYC_HEADER = b'\x61\x0D\x0D\x0A\x00\x00\x00\x00\xC1\xC5\xC0\x64\x9C\x27\x00\x00'

with open('RightBack.pyc', 'rb') as src:
    data = src.read()

for f in get_funcs(RightBack):
    new_code = fix_invalid(f)
    code_obj = recompile(f, new_code)
    # NOTE(review): this rebinds the module attribute to a code object, not a
    # function; nothing later in this script reads it.
    setattr(RightBack, f.__name__, code_obj)

    with open(f.__name__ + '.pyc', 'wb') as byc:
        byc.write(PYC_HEADER)
        marshal.dump(code_obj, byc)

# Opened only after all functions were processed, so a failure above no longer
# leaves a truncated, empty RightBack_fix.pyc behind.
# NOTE(review): the original bytes are written back unchanged; the repaired
# functions only exist in the per-function .pyc files.
with open('RightBack_fix.pyc', 'wb') as file:
    file.write(data)