# 初探 Python bytecode

# preface: 起因是 WMCTF 的一道 Python 题,考点包括修复 Python 字节码花指令和虚拟机等。算是萌新第一次见 Python 花指令,花了一上午学习并修好了花指令;修完后依然反编译不了,不过可以拿到字节码,理论上手动把字节码翻译回 Python 源码就没问题了。但实际上这里的字节码结构也有点问题,顺便从学长那里听说了 ANTLR 这种东西。然后手动翻译出来的伪代码也错漏百出……最后是学长写了脚本把它自动修好,才得以告一段落。看来确实得回来补补 Python 编译结构的功课了。
# 题目用的是 Python 3.9。查阅 CPython 源码时,直接在 GitHub 上切到对应版本的分支(cpython 3.9)即可,其他版本同理。
# image-20230823153127421
# 在 marshal 序列化后的 .pyc 中,co_code 以"长度字段 + 字节码"的形式存储,patch 字节码时需要同步修复这个长度字段;其他字段也需要保持一致(比赛时就是因为 co_lnotab 与字节码偏移对不齐,才导致无法反编译)

# image-20230823154536966(原文截图,未随文附上)

# co_lnotab 的格式可以参考官方文档 cpython/Objects/lnotab_notes.txt(github.com 上 python/cpython 仓库的 3.9 分支)
# 去花脚本(R1mao 那里偷来的):
import dis
import RightBack
import marshal
import types
from graphviz import Digraph
 
def get_funcs(module):
    """Collect the plain Python functions defined as attributes of *module*.

    Attributes are visited in ``dir()`` order. Anything that is not a
    ``types.FunctionType`` (builtins, classes, constants, ...) is skipped.
    """
    candidates = (getattr(module, attr_name) for attr_name in dir(module))
    return [obj for obj in candidates if isinstance(obj, types.FunctionType)]
def get_extend_size(code, offset):
    """Return the byte length of the EXTENDED_ARG prefix starting at *offset*.

    Instructions are 2 bytes (opcode, arg), so the result is a multiple of 2.
    It is 0 when the instruction at *offset* is not EXTENDED_ARG. Scanning
    stops at the end of *code*, so a buffer that ends in a dangling prefix
    yields its length instead of raising IndexError.
    """
    extended_arg = dis.opmap['EXTENDED_ARG']
    end = len(code)
    ptr = offset
    while ptr < end and code[ptr] == extended_arg:
        ptr += 2
    return ptr - offset
    
class InstWrapper:
    """One instruction (with its EXTENDED_ARG prefix) as a control-flow graph node.

    Attributes:
        offset: byte offset of the node, i.e. of its first prefix instruction
            if it has one. The recompiler overwrites it with the new position.
        prefix_insn: EXTENDED_ARG instructions preceding the main instruction.
        main_insn: the instruction itself.
        preds / succs: sets of neighbouring ``InstWrapper`` nodes.
    """

    def __init__(self, prefix_insn : list[dis.Instruction], main_insn : dis.Instruction):
        self.prefix_insn = prefix_insn
        self.main_insn = main_insn
        # A prefixed instruction starts at its first EXTENDED_ARG.
        self.offset = prefix_insn[0].offset if prefix_insn else main_insn.offset
        self.preds = set()
        self.succs = set()

    def edge(self, v1):
        """Add the directed control-flow edge ``self -> v1``."""
        v1.preds.add(self)
        self.succs.add(v1)

    def unlink(self):
        """Detach this node from its neighbours.

        Only the neighbours' sets are updated. ``self.preds`` and
        ``self.succs`` are left intact so the caller can still inspect them,
        e.g. to splice the predecessor directly to the successor.
        """
        for n in self.preds:
            n.succs.remove(self)
        for n in self.succs:
            n.preds.remove(self)

    def size(self):
        """Encoded size in bytes: 2 per instruction, prefix included."""
        return 2 * (len(self.prefix_insn) + 1)

    def print_asm(self):
        """Encode this node at ``self.offset`` as a list of byte values.

        Non-jump instructions are emitted unchanged; a missing arg becomes 0.
        For a jump, the target is the successor that is not the fall-through
        neighbour (``self.offset + self.size()``). When every successor is the
        fall-through neighbour, the jump targets the next instruction.
        Relative jumps are encoded as the byte distance from the end of this
        instruction.

        Raises:
            ValueError: if the re-encoded jump argument is negative or does
                not fit in the available arg bytes. The number of
                EXTENDED_ARG prefixes is fixed, so truncating the argument
                would silently produce a wrong jump.
        """
        assert len(self.succs) <= 2
        opcode = self.main_insn.opcode
        if opcode not in dis.hasjabs and opcode not in dis.hasjrel:
            result = []
            for p in self.prefix_insn:
                result += [p.opcode, p.arg]
            result += [opcode, self.main_insn.arg or 0]
            return result

        fall_through = self.offset + self.size()
        # NOTE(review): with two non-fall-through successors (a broken layout),
        # set iteration order decides which one is chosen.
        ref_target = next(
            (n.offset for n in self.succs if n.offset != fall_through),
            fall_through,
        )
        argval = ref_target
        if opcode in dis.hasjrel:
            # CPython 3.9 relative jumps count bytes from the end of the instruction.
            argval = ref_target - fall_through

        byte_size = len(self.prefix_insn) + 1
        if not 0 <= argval < (1 << (8 * byte_size)):
            raise ValueError(
                'jump argument %d at offset %d does not fit in %d byte(s)'
                % (argval, self.offset, byte_size))
        raw = argval.to_bytes(byte_size, byteorder='big')
        result = []
        for high_byte in raw[:-1]:
            result += [dis.opmap['EXTENDED_ARG'], high_byte]
        result += [opcode, raw[-1]]
        return result

    def __str__(self) -> str:
        return '%04d %s\t(%s)' % (self.offset, self.main_insn.opname, str(self.main_insn.argval) if self.main_insn.argval else '')
    
def fix_invalid(func):
    """Return a copy of *func*'s bytecode with every unreachable byte set to NOP.

    Execution is traced from offset 0. The trace follows the fall-through
    edge of every instruction except RETURN_VALUE / JUMP_FORWARD /
    JUMP_ABSOLUTE, plus the target edge of every jump. Each reached
    instruction, together with its EXTENDED_ARG prefix, is kept byte-for-byte.
    All other bytes (opcode and argument alike) are overwritten with the NOP
    opcode, so the code keeps its length and every offset stays valid. This
    presumably wipes out junk instructions (花指令) hidden in unreachable
    regions by the obfuscator.

    Raises:
        ValueError: if tracing reaches an offset that is not the start of an
            instruction (e.g. falls off the end of the code), or an
            EXTENDED_ARG prefix with no instruction after it.
    """
    raw_code = func.__code__.co_code
    insn_map = {insn.offset: insn for insn in dis._get_instructions_bytes(raw_code)}

    # start offset of each reachable instruction (prefix included) -> byte length
    trace_result = {}
    # Iterative worklist rather than recursion, so long straight-line
    # functions do not exceed the interpreter's recursion limit.
    worklist = [0]
    while worklist:
        offset = worklist.pop()
        if offset in trace_result:
            continue
        if offset not in insn_map:
            raise ValueError(
                'trace reached offset %d, which is not an instruction' % offset)
        extend_size = get_extend_size(raw_code, offset)
        cur = insn_map.get(offset + extend_size)
        if cur is None:
            raise ValueError('dangling EXTENDED_ARG prefix at offset %d' % offset)
        cur_size = extend_size + 2
        trace_result[offset] = cur_size

        if cur.opname == 'RETURN_VALUE':
            continue
        if cur.opcode in dis.hasjabs or cur.opcode in dis.hasjrel:
            worklist.append(cur.argval)  # argval is the absolute jump target
            if cur.opname in ('JUMP_FORWARD', 'JUMP_ABSOLUTE'):
                continue  # unconditional jump: no fall-through
        worklist.append(offset + cur_size)

    nop = dis.opmap['NOP']
    reachable = set()
    for start, size in trace_result.items():
        reachable.update(range(start, start + size))
    return bytes(b if i in reachable else nop for i, b in enumerate(raw_code))
def recompile(func, new_code):
    """Rebuild *func*'s code object from cleaned bytecode, dropping junk jumps.

    Steps:
      1. Wrap each instruction (with its EXTENDED_ARG prefix) in an
         ``InstWrapper`` and link control-flow edges. Every instruction except
         RETURN_VALUE / JUMP_FORWARD / JUMP_ABSOLUTE gets a fall-through edge,
         and every jump gets an edge to its target.
      2. Repeatedly remove nodes without predecessors (except the entry at
         offset 0). Splice out JUMP_FORWARD nodes that have exactly one
         predecessor and one successor.
      3. Lay the nodes out by depth-first traversal from the entry, assigning
         new offsets. Emit them in that same order and re-encode jump
         arguments via ``InstWrapper.print_asm``.

    Args:
        func: the original function. Its code object supplies every field
            except ``co_code`` and ``co_name``.
        new_code: bytecode to rebuild from (presumably the output of
            ``fix_invalid``).

    Returns:
        A new code object carrying the rewritten bytecode, named
        ``func.__name__``.

    Raises:
        ValueError: if no instruction remains at offset 0, or if
            ``print_asm`` cannot encode a jump.
    """
    no_fall_through = ('RETURN_VALUE', 'JUMP_FORWARD', 'JUMP_ABSOLUTE')

    # 1. Build the control-flow graph.
    by_offset = {}  # original offset -> node
    waiting = {}    # forward-jump target offset -> jump nodes to link once it appears
    prefix = []     # EXTENDED_ARG instructions seen since the last real instruction
    last = None
    for insn in dis._get_instructions_bytes(new_code):
        if insn.opname == 'EXTENDED_ARG':
            prefix.append(insn)
            continue
        inst = InstWrapper(prefix.copy(), insn)
        prefix.clear()
        for jumper in waiting.pop(inst.offset, ()):
            jumper.edge(inst)
        if last is not None and last.main_insn.opname not in no_fall_through:
            last.edge(inst)
        last = inst
        if insn.opcode in dis.hasjabs or insn.opcode in dis.hasjrel:
            target = insn.argval  # absolute target offset
            if target in by_offset:
                inst.edge(by_offset[target])
            else:
                waiting.setdefault(target, set()).add(inst)
        by_offset[inst.offset] = inst
    nodes = list(by_offset.values())

    # 2. Prune unreachable nodes and single-entry/single-exit JUMP_FORWARDs.
    while True:
        to_remove = []
        for n in nodes:
            if not n.preds and n.offset != 0:
                # No way in, e.g. the NOP runs written by fix_invalid.
                n.unlink()
                to_remove.append(n)
            elif len(n.preds) == 1 and len(n.succs) == 1 and n.main_insn.opname == 'JUMP_FORWARD':
                # Redundant hop: connect its predecessor straight to its target.
                n.unlink()
                p = n.preds.pop()
                s = n.succs.pop()
                p.edge(s)
                to_remove.append(n)
        if not to_remove:
            break
        for n in to_remove:
            nodes.remove(n)

    # 3. Assign new offsets in DFS preorder and emit in that same order.
    entry_point = next((n for n in nodes if n.offset == 0), None)
    if entry_point is None:
        raise ValueError('%s: no instruction left at offset 0' % func.__name__)

    # NOTE(review): successors are visited in offset order without knowing which
    # edge is the fall-through. A conditional jump whose fall-through successor
    # has a higher offset than its (unplaced) jump target gets the two swapped,
    # and print_asm then encodes the jump towards the wrong branch. Fixing this
    # needs edge kinds recorded in step 1.
    layout = []
    placed = set()  # the entry is marked too, so a back-edge cannot re-place it
    next_offset = 0
    stack = [entry_point]
    while stack:
        cur = stack.pop()
        if cur in placed:
            continue
        placed.add(cur)
        cur.offset = next_offset
        next_offset += cur.size()
        layout.append(cur)
        succs = sorted(cur.succs, key=lambda w: w.offset)
        # Push in reverse so the lowest-offset successor is laid out next.
        stack.extend(reversed([s for s in succs if s not in placed]))
    # Nodes the traversal did not reach are still emitted, with consistent offsets.
    for n in nodes:
        if n not in placed:
            n.offset = next_offset
            next_offset += n.size()
            layout.append(n)

    new_bytecode = []
    for w in layout:
        new_bytecode += w.print_asm()
    print(new_bytecode)
    new_bytecode = bytes(new_bytecode)
    # NOTE(review): co_lnotab is carried over unchanged, so line numbers may not
    # match the rewritten bytecode.
    code_obj = func.__code__.replace(co_code=new_bytecode, co_name=func.__name__)

    # Debug aid: control-flow graph of the rewritten function (call dot.view() to render).
    dot = Digraph(func.__name__)
    for n in nodes:
        dot.node(str(n.offset), str(n))
    for n in nodes:
        for pred in n.preds:
            dot.edge(str(pred.offset), str(n.offset))
    #dot.view()
    return code_obj
# Header written in front of each dumped code object. 61 0D 0D 0A is the
# CPython 3.9 magic number (3425). The next 4 bytes are the PEP 552 flags
# (0 = timestamp-based pyc), followed by a source mtime and size. Those last
# two are presumably copied from RightBack.pyc; TODO confirm.
PYC_HEADER = b'\x61\x0D\x0D\x0A\x00\x00\x00\x00\xC1\xC5\xC0\x64\x9C\x27\x00\x00'

with open('RightBack.pyc', 'rb') as file:
    data = file.read()

# NOTE(review): RightBack_fix.pyc receives an unmodified copy of RightBack.pyc;
# the repaired functions are only written to separate <name>.pyc files. The
# in-place patch was disabled in the original script:
#     data = data.replace(f.__code__.co_code, new_code)
with open('RightBack_fix.pyc', 'wb') as file:
    for f in get_funcs(RightBack):
        new_code = fix_invalid(f)
        code_obj = recompile(f, new_code)
        setattr(RightBack, f.__name__, code_obj)
        with open(f.__name__ + '.pyc', 'wb') as byc:
            byc.write(PYC_HEADER)
            marshal.dump(code_obj, byc)
    file.write(data)