最近我在用梯度下降算法繪制神經網絡的數(shù)據(jù)時,遇到了一些算法性能的問題。梯度下降算法的代碼如下(偽代碼):
def gradient_descent(): # the gradient descent code plotly.write(X, Y)
一般來說,當網絡請求 plot.ly 繪圖時會阻塞等待返回,于是也會影響到其他的梯度下降函數(shù)的執(zhí)行速度。
一種解決辦法是每調用一次 plotly.write 函數(shù)就開啟一個新的線程,但是這種方法感覺不是很好。 我不想用一個像 cerely(一種分布式任務隊列)一樣大而全的任務隊列框架,因為框架對于我的這點需求來說太重了,并且我的繪圖也并不需要 redis 來持久化數(shù)據(jù)。
那用什么辦法解決呢?我在 python 中寫了一個很小的任務隊列,它可以在一個單獨的線程中調用 plotly.write函數(shù)。下面是程序代碼。
from threading import Thread import Queue import time class TaskQueue(Queue.Queue):
首先我們繼承 Queue.Queue 類。從 Queue.Queue 類可以繼承 get 和 put 方法,以及隊列的行為。
def __init__(self, num_workers=1): Queue.Queue.__init__(self) self.num_workers = num_workers self.start_workers()
初始化的時候,我們可以不用考慮工作線程的數(shù)量。
def add_task(self, task, *args, **kwargs): args = args or () kwargs = kwargs or {} self.put((task, args, kwargs))
我們把 task, args, kwargs 以元組的形式存儲在隊列中。*args 可以傳遞數(shù)量不等的參數(shù),**kwargs 可以傳遞命名參數(shù)。
def start_workers(self): for i in range(self.num_workers): t = Thread(target=self.worker) t.daemon = True t.start()
我們?yōu)槊總€ worker 創(chuàng)建一個線程,然后在后臺刪除。
下面是 worker 函數(shù)的代碼:
def worker(self): while True: tupl = self.get() item, args, kwargs = self.get() item(*args, **kwargs) self.task_done()
worker 函數(shù)獲取隊列頂端的任務,并根據(jù)輸入參數(shù)運行,除此之外,沒有其他的功能。下面是隊列的代碼:
我們可以通過下面的代碼測試:
def blokkah(*args, **kwargs): time.sleep(5) print “Blokkah mofo!” q = TaskQueue(num_workers=5) for item in range(1): q.add_task(blokkah) q.join() # wait for all the tasks to finish. print “All done!”
Blokkah 是我們要做的任務名稱。隊列已經緩存在內存中,并且沒有執(zhí)行很多任務。下面的步驟是把主隊列當做單獨的進程來運行,這樣主程序退出以及執(zhí)行數(shù)據(jù)庫持久化時,隊列任務不會停止運行。但是這個例子很好地展示了如何從一個很簡單的小任務寫成像工作隊列這樣復雜的程序。
def gradient_descent(): # the gradient descent code queue.add_task(plotly.write, x=X, y=Y)
修改之后,我的梯度下降算法工作效率似乎更高了。如果你很感興趣的話,可以參考下面的代碼。
from threading import Thread import Queue import time class TaskQueue(Queue.Queue): def __init__(self, num_workers=1): Queue.Queue.__init__(self) self.num_workers = num_workers self.start_workers() def add_task(self, task, *args, **kwargs): args = args or () kwargs = kwargs or {} self.put((task, args, kwargs)) def start_workers(self): for i in range(self.num_workers): t = Thread(target=self.worker) t.daemon = True t.start() def worker(self): while True: tupl = self.get() item, args, kwargs = self.get() item(*args, **kwargs) self.task_done() def tests(): def blokkah(*args, **kwargs): time.sleep(5) print "Blokkah mofo!" q = TaskQueue(num_workers=5) for item in range(10): q.add_task(blokkah) q.join() # block until all tasks are done print "All done!" if __name__ == "__main__": tests()
更多文章、技術交流、商務合作、聯(lián)系博主
微信掃碼或搜索:z360901061

微信掃一掃加我為好友
QQ號聯(lián)系: 360901061
您的支持是博主寫作最大的動力,如果您喜歡我的文章,感覺我的文章對您有幫助,請用微信掃描下面二維碼支持博主2元、5元、10元、20元等您想捐的金額吧,狠狠點擊下面給點支持吧,站長非常感激您!手機微信長按不能支付解決辦法:請將微信支付二維碼保存到相冊,切換到微信,然后點擊微信右上角掃一掃功能,選擇支付二維碼完成支付。
【本文對您有幫助就好】元
