最近我在用梯度下降演算法繪製神經網路的資料時,遇到了一些演算法效能的問題。梯度下降演算法的程式碼如下(虛擬碼):
1 2 3 |
def gradient_descent(): # the gradient descent code plotly.write(X, Y) |
一般來說,當網路請求 plot.ly 繪圖時會阻塞等待返回,於是也會影響到其他的梯度下降函式的執行速度。
一種解決辦法是每呼叫一次 plotly.write 函式就開啟一個新的執行緒,但是這種方法感覺不是很好。 我不想用一個像 cerely(一種分散式任務佇列)一樣大而全的任務佇列框架,因為框架對於我的這點需求來說太重了,並且我的繪圖也並不需要 redis 來持久化資料。
那用什麼辦法解決呢?我在 python 中寫了一個很小的任務佇列,它可以在一個單獨的執行緒中呼叫 plotly.write函式。下面是程式程式碼。
1 2 3 4 5 |
from threading import Thread import Queue import time class TaskQueue(Queue.Queue): |
首先我們繼承 Queue.Queue 類。從 Queue.Queue 類可以繼承 get 和 put 方法,以及佇列的行為。
1 2 3 4 |
def __init__(self, num_workers=1): Queue.Queue.__init__(self) self.num_workers = num_workers self.start_workers() |
初始化的時候,我們可以不用考慮工作執行緒的數量。
1 2 3 4 |
def add_task(self, task, *args, **kwargs): args = args or () kwargs = kwargs or {} self.put((task, args, kwargs)) |
我們把 task, args, kwargs 以元組的形式儲存在佇列中。*args 可以傳遞數量不等的引數,**kwargs 可以傳遞命名引數。
1 2 3 4 5 |
def start_workers(self): for i in range(self.num_workers): t = Thread(target=self.worker) t.daemon = True t.start() |
我們為每個 worker 建立一個執行緒,然後在後臺刪除。
下面是 worker 函式的程式碼:
1 2 3 4 5 6 |
def worker(self): while True: tupl = self.get() item, args, kwargs = self.get() item(*args, **kwargs) self.task_done() |
worker 函式獲取佇列頂端的任務,並根據輸入引數執行,除此之外,沒有其他的功能。下面是佇列的程式碼:
我們可以通過下面的程式碼測試:
1 2 3 4 5 6 7 8 9 10 11 12 |
def blokkah(*args, **kwargs): time.sleep(5) print “Blokkah mofo!” q = TaskQueue(num_workers=5) for item in range(1): q.add_task(blokkah) q.join() # wait for all the tasks to finish. print “All done!” |
Blokkah 是我們要做的任務名稱。佇列已經快取在記憶體中,並且沒有執行很多工。下面的步驟是把主佇列當做單獨的程式來執行,這樣主程式退出以及執行資料庫持久化時,佇列任務不會停止執行。但是這個例子很好地展示瞭如何從一個很簡單的小任務寫成像工作佇列這樣複雜的程式。
1 2 3 |
def gradient_descent(): # the gradient descent code queue.add_task(plotly.write, x=X, y=Y) |
修改之後,我的梯度下降演算法工作效率似乎更高了。如果你很感興趣的話,可以參考下面的程式碼。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from threading import Thread import Queue import time class TaskQueue(Queue.Queue): def __init__(self, num_workers=1): Queue.Queue.__init__(self) self.num_workers = num_workers self.start_workers() def add_task(self, task, *args, **kwargs): args = args or () kwargs = kwargs or {} self.put((task, args, kwargs)) def start_workers(self): for i in range(self.num_workers): t = Thread(target=self.worker) t.daemon = True t.start() def worker(self): while True: tupl = self.get() item, args, kwargs = self.get() item(*args, **kwargs) self.task_done() def tests(): def blokkah(*args, **kwargs): time.sleep(5) print "Blokkah mofo!" q = TaskQueue(num_workers=5) for item in range(10): q.add_task(blokkah) q.join() # block until all tasks are done print "All done!" if __name__ == "__main__": tests() |