【譯】使用 Python 編寫虛擬機器直譯器

OneAPM官方技術部落格發表於2015-06-19

原文地址:Making a simple VM interpreter in Python

更新:根據大家的評論我對程式碼做了輕微的改動。感謝 robin-gvx、 bs4h 和 Dagur,具體程式碼見這裡

Stack Machine 本身並沒有任何的暫存器,它將所需要處理的值全部放入堆疊中而後進行處理。Stack Machine 雖然簡單但是卻十分強大,這也是為神馬 Python,Java,PostScript,Forth 和其他語言都選擇它作為自己的虛擬機器的原因。

首先,我們先來談談堆疊。我們需要一個指令指標棧用於儲存返回地址。這樣當我們呼叫了一個子例程(比如呼叫一個函式)的時候我們就能夠返回到我們開始呼叫的地方了。我們可以使用自修改程式碼(self-modifying code)來做這件事,恰如 Donald Knuth 發起的 MIX 所做的那樣。但是如果這麼做的話你不得不自己維護堆疊從而保證遞迴能正常工作。在這篇文章中,我並不會真正的實現子例程呼叫,但是要實現它其實並不難(可以考慮把實現它當成練習)。

有了堆疊之後你會省很多事兒。舉個例子來說,考慮這樣一個表示式(2+3)*4。在 Stack Machine 上與這個表示式等價的程式碼為 2 3 + 4 *。首先,將 23 推入堆疊中,接下來的是操作符 +,此時讓堆疊彈出這兩個數值,再把它兩加合之後的結果重新入棧。然後將 4 入堆,而後讓堆疊彈出兩個數值,再把他們相乘之後的結果重新入棧。多麼簡單啊!

讓我們開始寫一個簡單的堆疊類吧。讓這個類繼承 collections.deque

from collections import deque

class Stack(deque):
push = deque.append

def top(self):
    return self[-1]

現在我們有了 pushpoptop 這三個方法。top 方法用於檢視棧頂元素。

接下來,我們實現虛擬機器這個類。在虛擬機器中我們需要兩個堆疊以及一些記憶體空間來儲存程式本身(譯者注:這裡的程式請結合下文理解)。得益於 Pyhton 的動態型別我們可以往 list 中放入任何型別。唯一的問題是我們無法區分出哪些是字串哪些是內建函式。正確的做法是隻將真正的 Python 函式放入 list 中。我可能會在將來實現這一點。

我們同時還需要一個指令指標指向程式中下一個要執行的程式碼。

class Machine:
def __init__(self, code):
    self.data_stack = Stack()
    self.return_addr_stack = Stack()
    self.instruction_pointer = 0
    self.code = code

這時候我們增加一些方便使用的函式省得以後多敲鍵盤。

def pop(self):
    return self.data_stack.pop()

def push(self, value):
    self.data_stack.push(value)

def top(self):
    return self.data_stack.top()

然後我們增加一個 dispatch 函式來完成每一個操作碼做的事兒(我們並不是真正的使用操作碼,只是動態展開它,你懂的)。首先,增加一個直譯器所必須的迴圈:

def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)

誠如您所見的,這貨只好好的做一件事兒,即獲取下一條指令,讓指令指標執自增,然後根據操作碼分別處理。dispatch 函式的程式碼稍微長了一點。

def dispatch(self, op):
    dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":       self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":     self.drop,
        "dup":      self.dup,
        "if":       self.if_stmt,
        "jmp":      self.jmp,
        "over":     self.over,
        "print":    self.print_,
        "println":  self.println,
        "read":     self.read,
        "stack":    self.dump_stack,
        "swap":     self.swap,
    }

    if op in dispatch_map:
        dispatch_map[op]()
    elif isinstance(op, int):
        # push numbers on the data stack
        self.push(op)
    elif isinstance(op, str) and op[0]==op[-1]=='"':
        # push quoted strings on the data stack
        self.push(op[1:-1])
    else:
        raise RuntimeError("Unknown opcode: '%s'" % op)

基本上,這段程式碼只是根據操作碼查詢是都有對應的處理函式,例如 * 對應 self.muldrop 對應 self.dropdup對應 self.dup。順便說一句,你在這裡看到的這段程式碼其實本質上就是簡單版的 Forth。而且,Forth 語言還是值得您看看的。

總之捏,它一但發現操作碼是 * 的話就直接呼叫 self.mul 並執行它。就像這樣:

def mul(self):
    self.push(self.pop() * self.pop())

其他的函式也是類似這樣的。如果我們在 dispatch_map 中查詢不到相應操作函式,我們首先檢查他是不是數字型別,如果是的話直接入棧;如果是被引號括起來的字串的話也是同樣處理--直接入棧。

截止現在,恭喜你,一個虛擬機器就完成了。

讓我們定義更多的操作,然後使用我們剛完成的虛擬機器和 p-code 語言來寫程式。

# Allow to use "print" as a name for our own method:
from __future__ import print_function

# ...

def plus(self):
    self.push(self.pop() + self.pop())

def minus(self):
    last = self.pop()
    self.push(self.pop() - last)

def mul(self):
    self.push(self.pop() * self.pop())

def div(self):
    last = self.pop()
    self.push(self.pop() / last)

def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()

def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()

讓我們用我們的虛擬機器寫個與 print((2+3)*4) 等同效果的例子。

Machine([2, 3, "+", 4, "*", "println"]).run() 你可以試著執行它。

現在引入一個新的操作 jump, 即 go-to 操作

def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")

它只改變指令指標的值。我們再看看分支跳轉是怎麼做的。

def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)

這同樣也是很直白的。如果你想要新增一個條件跳轉,你只要簡單的執行 test-value true-value false-value IF JMP 就可以了.(分支處理是很常見的操作,許多虛擬機器都提供類似 JNE 這樣的操作。JNEjump if not equal 的縮寫)。

下面的程式要求使用者輸入兩個數字,然後列印出他們的和和乘積。

Machine([
'"Enter a number: "', "print", "read", "cast_int",
'"Enter another number: "', "print", "read", "cast_int",
"over", "over",
'"Their sum is: "', "print", "+", "println",
'"Their product is: "', "print", "*", "println"
]).run()

overreadcast_int 這三個操作是長這樣滴:

def cast_int(self):
    self.push(int(self.pop()))

def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)

def read(self):
    self.push(raw_input())

以下這一段程式要求使用者輸入一個數字,然後列印出這個數字是奇數還是偶數。

Machine([
'"Enter a number: "', "print", "read", "cast_int",
'"The number "', "print", "dup", "print", '" is "', "print",
2, "%", 0, "==", '"even."', '"odd."', "if", "println",
0, "jmp" # loop forever!
]).run()

這裡有個小練習給你去實現:增加 callreturn 這兩個操作碼。call 操作碼將會做如下事情 :將當前地址推入返回堆疊中,然後呼叫 self.jmp()return 操作碼將會做如下事情:返回堆疊彈棧,將彈棧出來元素的值賦予指令指標(這個值可以讓你跳轉回去或者從 call 呼叫中返回)。當你完成這兩個命令,那麼你的虛擬機器就可以呼叫子例程了。

一個簡單的解析器

創造一個模仿上述程式的小型語言。我們將把它編譯成我們的機器碼。

 import tokenize
 from StringIO import StringIO

# ...

def parse(text):
tokens =   tokenize.generate_tokens(StringIO(text).readline)
for toknum, tokval, _, _, _ in tokens:
    if toknum == tokenize.NUMBER:
        yield int(tokval)
    elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
        yield tokval
    elif toknum == tokenize.ENDMARKER:
        break
    else:
        raise RuntimeError("Unknown token %s: '%s'" %
                (tokenize.tok_name[toknum], tokval))

一個簡單的優化:常量摺疊

常量摺疊(Constant folding)是窺孔優化(peephole optimization)的一個例子,也即是說再在編譯期間可以針對某些明顯的程式碼片段做些預計算的工作。比如,對於涉及到常量的數學表示式例如 2 3 +就可以很輕鬆的實現這種優化。

def constant_fold(code):
"""Constant-folds simple mathematical expressions like 2 3 + to 5."""
while True:
    # Find two consecutive numbers and an arithmetic operator
    for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
        if isinstance(a, int) and isinstance(b, int) \
                and op in {"+", "-", "*", "/"}:
            m = Machine((a, b, op))
            m.run()
            code[i:i+3] = [m.top()]
            print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
            break
    else:
        break
return code

採用常量摺疊遇到唯一問題就是我們不得不更新跳轉地址,但在很多情況這是很難辦到的(例如:test cast_int jmp)。針對這個問題有很多解決方法,其中一個簡單的方法就是隻允許跳轉到程式中的命名標籤上,然後在優化之後解析出他們真正的地址。

如果你實現了 Forth words,也即函式,你可以做更多的優化,比如刪除可能永遠不會被用到的程式程式碼(dead code elimination

REPL

我們可以創造一個簡單的 PERL,就像這樣

def repl():
print('Hit CTRL+D or type "exit" to quit.')

while True:
    try:
        source = raw_input("> ")
        code = list(parse(source))
        code = constant_fold(code)
        Machine(code).run()
    except (RuntimeError, IndexError) as e:
        print("IndexError: %s" % e)
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt")

用一些簡單的程式來測試我們的 REPL

> 2 3 + 4 * println
Constant-folded 2+3 to 5
Constant-folded 5*4 to 20
20
> 12 dup * println
144
> "Hello, world!" dup println println
Hello, world!
Hello, world!
你可以看到,常量摺疊看起來運轉正常。在第一個例子中,它把整個程式優化成這樣 20 println。

下一步

當你新增完 callreturn 之後,你便可以讓使用者定義自己的函式了。在Forth 中函式被稱為 words,他們以冒號開頭緊接著是名字然後以分號結束。例如,一個整數平方的 word 是長這樣滴

: square dup * ;

實際上,你可以試試把這一段放在程式中,比如 Gforth

$ gforth
Gforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.
Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license'
Type `bye' to exit
: square dup * ;  ok
12 square . 144  ok

你可以在解析器中通過發現 : 來支援這一點。一旦你發現一個冒號,你必須記錄下它的名字及其地址(比如:在程式中的位置)然後把他們插入到符號表(symbol table)中。簡單起見,你甚至可以把整個函式的程式碼(包括分號)放在字典中,譬如:

symbol_table = {
"square": ["dup", "*"]
# ...
    }

當你完成了解析的工作,你可以連線你的程式:遍歷整個主程式並且在符號表中尋找自定義函式的地方。一旦你找到一個並且它沒有在主程式的後面出現,那麼你可以把它附加到主程式的後面。然後用 <address> call 替換掉 square,這裡的 <address> 是函式插入的地址。

為了保證程式能正常執行,你應該考慮剔除 jmp 操作。否則的話,你不得不解析它們。它確實能執行,但是你得按照使用者編寫程式的順序儲存它們。舉例來說,你想在子例程之間移動,你要格外小心。你可能需要新增 exit 函式用於停止程式(可能需要告訴作業系統返回值),這樣主程式就不會繼續執行以至於跑到子例程中。

實際上,一個好的程式空間佈局很有可能把主程式當成一個名為 main 的子例程。或者由你決定搞成什麼樣子。

如您所見,這一切都是很有趣的,而且通過這一過程你也學會了很多關於程式碼生成、連結、程式空間佈局相關的知識。

更多能做的事兒

你可以使用 Python 位元組碼生成庫來嘗試將虛擬機器程式碼為原生的 Python 位元組碼。或者用 Java 實現執行在 JVM 上面,這樣你就可以自由使用 JITing

同樣的,你也可以嘗試下register machine。你可以嘗試用棧幀(stack frames)實現呼叫棧(call stack),並基於此建立呼叫會話。

最後,如果你不喜歡類似 Forth 這樣的語言,你可以創造執行於這個虛擬機器之上的自定義語言。譬如,你可以把類似 (2+3)*4 這樣的中綴表示式轉化成 2 3 + 4 * 然後生成程式碼。你也可以允許 C 風格的程式碼塊 { ... } 這樣的話,語句 if ( test ) { ... } else { ... } 將會被翻譯成

<true/false test>
<address of true block>
<address of false block>
if
jmp

<true block>
<address of end of entire if-statement> jmp

<false block>
<address of end of entire if-statement> jmp

例子,

Address  Code
-------  ----
 0       2 3 >
 3       7        # Address of true-block
 4       11       # Address of false-block
 5       if
 6       jmp      # Conditional jump based on test

# True-block

7     "Two is greater than three."    
8       println
9       15       # Continue main program
10       jmp

# False-block ("else { ... }")
11       "Two is less than three."
12       println
13       15       # Continue main program
14       jmp

# If-statement finished, main program continues here
15       ...

對了,你還需要新增比較操作符 != < <= > >=

我已經在我的 C++ stack machine 實現了這些東東,你可以參考下。

我已經把這裡呈現出來的程式碼搞成了個專案 Crianza,它使用了更多的優化和實驗性質的模型來吧程式編譯成 Python 位元組碼。

祝好運!

完整的程式碼

下面是全部的程式碼,相容 Python 2 和 Python 3

你可以通過 這裡 得到它。

#!/usr/bin/env python
# coding: utf-8

"""
A simple VM interpreter.

Code from the post at http://csl.name/post/vm/
This version should work on both Python 2 and 3.
"""

from __future__ import print_function
from collections import deque
from io import StringIO
import sys
import tokenize


def get_input(*args, **kw):
"""Read a string from standard input."""
if sys.version[0] == "2":
    return raw_input(*args, **kw)
else:
    return input(*args, **kw)


class Stack(deque):
push = deque.append

def top(self):
    return self[-1]


class Machine:
def __init__(self, code):
    self.data_stack = Stack()
    self.return_stack = Stack()
    self.instruction_pointer = 0
    self.code = code

def pop(self):
    return self.data_stack.pop()

def push(self, value):
    self.data_stack.push(value)

def top(self):
    return self.data_stack.top()

def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)

def dispatch(self, op):
    dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":       self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":     self.drop,
        "dup":      self.dup,
        "exit":     self.exit,
        "if":       self.if_stmt,
        "jmp":      self.jmp,
        "over":     self.over,
        "print":    self.print,
        "println":  self.println,
        "read":     self.read,
        "stack":    self.dump_stack,
        "swap":     self.swap,
    }

    if op in dispatch_map:
        dispatch_map[op]()
    elif isinstance(op, int):
        self.push(op) # push numbers on stack
    elif isinstance(op, str) and op[0]==op[-1]=='"':
        self.push(op[1:-1]) # push quoted strings on stack
    else:
        raise RuntimeError("Unknown opcode: '%s'" % op)

# OPERATIONS FOLLOW:

def plus(self):
    self.push(self.pop() + self.pop())

def exit(self):
    sys.exit(0)

def minus(self):
    last = self.pop()
    self.push(self.pop() - last)

def mul(self):
    self.push(self.pop() * self.pop())

def div(self):
    last = self.pop()
    self.push(self.pop() / last)

def mod(self):
    last = self.pop()
    self.push(self.pop() % last)

def dup(self):
    self.push(self.top())

def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)

def drop(self):
    self.pop()

def swap(self):
    b = self.pop()
    a = self.pop()
    self.push(b)
    self.push(a)

def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()

def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()

def read(self):
    self.push(get_input())

def cast_int(self):
    self.push(int(self.pop()))

def cast_str(self):
    self.push(str(self.pop()))

def eq(self):
    self.push(self.pop() == self.pop())

def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)

def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")

def dump_stack(self):
    print("Data stack (top first):")

    for v in reversed(self.data_stack):
        print(" - type %s, value '%s'" % (type(v), v))


def parse(text):
# Note that the tokenizer module is intended for parsing Python source
# code, so if you're going to expand on the parser, you may have to use
# another tokenizer.

if sys.version[0] == "2":
    stream = StringIO(unicode(text))
else:
    stream = StringIO(text)

tokens = tokenize.generate_tokens(stream.readline)

for toknum, tokval, _, _, _ in tokens:
    if toknum == tokenize.NUMBER:
        yield int(tokval)
    elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
        yield tokval
    elif toknum == tokenize.ENDMARKER:
        break
    else:
        raise RuntimeError("Unknown token %s: '%s'" %
                (tokenize.tok_name[toknum], tokval))

def constant_fold(code):
"""Constant-folds simple mathematical expressions like 2 3 + to 5."""
while True:
    # Find two consecutive numbers and an arithmetic operator
    for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
        if isinstance(a, int) and isinstance(b, int) \
                and op in {"+", "-", "*", "/"}:
            m = Machine((a, b, op))
            m.run()
            code[i:i+3] = [m.top()]
            print("Constant-folded %s%s%s to %s" %     (a,op,b,m.top()))
            break
        else:
            break
        return code

def repl():
print('Hit CTRL+D or type "exit" to quit.')

while True:
    try:
        source = get_input("> ")
        code = list(parse(source))
        code = constant_fold(code)
        Machine(code).run()
    except (RuntimeError, IndexError) as e:
        print("IndexError: %s" % e)
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt")

def test(code = [2, 3, "+", 5, "*", "println"]):
print("Code before optimization: %s" % str(code))
optimized = constant_fold(code)
print("Code after optimization: %s" % str(optimized))

print("Stack after running original program:")
a = Machine(code)
a.run()
a.dump_stack()

print("Stack after running optimized program:")
b = Machine(optimized)
b.run()
b.dump_stack()

result = a.data_stack == b.data_stack
print("Result: %s" % ("OK" if result else "FAIL"))
return result

def examples():
print("** Program 1: Runs the code for `print((2+3)*4)`")
Machine([2, 3, "+", 4, "*", "println"]).run()

print("\n** Program 2: Ask for numbers, computes sum and product.")
Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"Enter another number: "', "print", "read", "cast_int",
    "over", "over",
    '"Their sum is: "', "print", "+", "println",
    '"Their product is: "', "print", "*", "println"
]).run()

print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).")
Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"The number "', "print", "dup", "print", '" is "', "print",
    2, "%", 0, "==", '"even."', '"odd."', "if", "println",
    0, "jmp" # loop forever!
]).run()


if __name__ == "__main__":
try:
    if len(sys.argv) > 1:
        cmd = sys.argv[1]
        if cmd == "repl":
            repl()
        elif cmd == "test":
            test()
            examples()
        else:
            print("Commands: repl, test")
    else:
        repl()
except EOFError:
    print("")

本文系OneAPM工程師編譯整理。OneAPM是中國基礎軟體領域的新興領軍企業,能幫助企業使用者和開發者輕鬆實現:緩慢的程式程式碼和SQL語句的實時抓取。想閱讀更多技術文章,請訪問OneAPM官方技術部落格

相關文章