Python技法3:匿名函式、回撥函式和高階函式

lonelyprince7發表於2021-10-20

1、定義匿名或行內函數

如果我們想提供一個短小的回撥函式供sort()這樣的函式用,但不想用def這樣的語句編寫一個單行的函式,我們可以藉助lambda表示式來編寫“內聯”式的函式。如下圖所示:

add = lambda x, y: x + y
print(add(2, 3)) # 5
print(add("hello", "world!")) # helloworld

可以看到,這裡用到的lambda表示式和普通的函式定義有著相同的功能。
lambda表示式常常做為回撥函式使用,有在排序以及對資料進行預處理時有許多用武之地,如下所示:

names = [ 'David Beazley', 'Brian Jones', 'Reymond Hettinger', 'Ned Batchelder']
sorted_names = sorted(names, key=lambda name: name.split()[-1].lower())
print(sorted_names)
# ['Ned Batchelder', 'David Beazley', 'Reymond Hettinger', 'Brian Jones']

lambda雖然靈活易用,但是侷限性也大,相當於其函式體中只能定義一條語句,不能執行條件分支、迭代、異常處理等操作。

2、在匿名函式中繫結變數的值

現在我們想在匿名函式定義時完成對特定變數(一般是常量)的繫結,以便後期使用。如果我們這樣寫:

x = 10
a = lambda y: x + y 
x = 20
b = lambda y: x + y

然後計算a(10)b(10)。你可能希望結果是2030,然而實際程式的執行結果會出人意料:結果是3030
這個問題的關鍵在於lambda表示式中的x是個自由變數(未繫結到本地作用域的變數),在執行時繫結而不是定義的時候繫結(其實普通函式中使用自由變數同理),而這裡執行a(10)的時候x已經變成了20,故最終a(10)的值為30。如果希望匿名函式在定義的時候繫結變數,而之後繫結值不再變化,那我們可以將想要繫結的變數做為預設引數,如下所示:

x = 10
a = lambda y, x=x: x + y
x = 20
b = lambda y, x=x: x + y
print(a(10)) # 20
print(b(10)) # 30

上面我們提到的這個陷阱常見於一些對lambda函式過於“聰明”的應用中。比如我們想用列表推導式來建立一個列表的lambda函式並期望lambda函式能記住迭代變數。

funcs = [lambda x: x + n for n in range(5)]
for f in funcs:
    print(f(0))
# 4
# 4
# 4
# 4
# 4

可以看到與我們期望的不同,所有lambda函式都認為n是4。如上所述,我們修改成以下程式碼即可:

funcs = [lambda x, n=n: x + n for n in range(5)]
for f in funcs:
    print(f(0))
# 0
# 1
# 2
# 3
# 4

2、讓帶有n個引數的可呼叫物件以較少的引數呼叫

假設我們現在有個n個引數的函式做為回撥函式使用,但這個函式需要的引數過多,而回撥函式只能有個引數。如果需要減少函式的引數數量,需要時用functools包。functools這個包內的函式全部為高階函式。高階函式即引數或(和)返回值為其他函式的函式。通常來說,此模組的功能適用於所有可呼叫物件。
比如functools.partial()就是一個高階函式, 它的原型如下:

functools.partial(func, /, *args, **keywords)

它接受一個func函式做為引數,並且它會返回一個新的newfunc物件,這個新的newfunc物件已經附帶了位置引數args和關鍵字引數keywords,之後在呼叫newfunc時就可以不用再傳已經設定好的引數了。如下所示:

def spam(a, b, c, d):
  print(a, b, c, d)

from functools import partial
s1 = partial(spam, 1) # 設定好a = 1(如果沒指定引數名,預設按順序設定)
s1(2, 3, 4) # 1 2 3 4

s2 = partial(spam, d=42) # 設定好d為42
s2(1, 2, 3) # 1 2 3 42

s3 = partial(spam, 1, 2, d=42) #設定好a = 1, b = 2, d = 42
s3(3) # 1 2 3 42

上面提到的技術常常用於將不相容的程式碼“粘”起來,尤其是在你呼叫別人的輪子,而別人寫好的函式不能修改的時候。比如我們有以下一組元組表示的點的座標:

points = [(1, 2), (3, 4), (5, 6), (7, 8)]

有已知的一個distance()函式可供使用,假設這是別人造的輪子不能修改。

import math
def distance(p1, p2):
    x1, y1 = p1
    x2, y2 = p2
    return math.hypot(x2 - x1, y2 - y1)

接下來我們想根據列表中這些點到一個定點pt=(4, 3)的距離來排序。我們知道列表的sort()方法
可以接受一個key引數(傳入一個回撥函式)來做自定義的排序處理。但傳入的回撥函式只能有一個引數,這裡的distance()函式有兩個引數,顯然不能直接做為回撥函式使用。下面我們用partical()來解決這個問題:

pt = (4, 3)
points.sort(key=partial(distance, pt)) # 先指定好一個引數為pt=(4,3)
print(points)
# [(3, 4), (1, 2), (5, 6), (7, 8)]

可以看到,排序正確執行。還有一種方法要臃腫些,那就是將回撥函式distance巢狀進另一個只有一個引數的lambda函式中:

pt = (4, 3)
points.sort(key=lambda p: distance(p, pt))
print(points)
# [(3, 4), (1, 2), (5, 6), (7, 8)]

這種方法一來臃腫,二來仍然存在我們上面提到過的一個毛病,如果我們定義回撥函式後對pt有所修改,就會發生我們上面所說的不愉快的事情:

pt = (4, 3)
func_key = lambda p: distance(p ,pt) 
pt = (0, 0) # 像這樣,後面pt變了就GG
points.sort(key=func_key)
print(points)
# [(1, 2), (3, 4), (5, 6), (7, 8)]

可以看到,最終排序的結果由於後面pt的改變而變得完全不同了。所以我們還是建議大家採用使用functools.partial()函式來達成目的。
下面這段程式碼也是用partial()函式來調整函式簽名的例子。這段程式碼利用multiprocessing模組以非同步方式計算某個結果,然後用一個回撥函式來列印該結果,該回撥函式可接受這個結果和一個事先指定好的日誌引數。

# result:回撥函式本身該接受的引數, log是我想使其擴充套件的引數
def output_result(result, log=None):
    if log is not None:
        log.debug('Got: %r', result)

def add(x, y):
    return x + y

if __name__ == '__main__':
    import logging
    from multiprocessing import Pool
    from functools import partial
    logging.basicConfig(level=logging.DEBUG)
    log = logging.getLogger('test')
    p = Pool()
    p.apply_async(add, (3, 4), callback=partial(output_result, log=log))
    p.close()
    p.join()

# DEBUG:test:Got: 7

下面這個例子則源於一個在編寫網路伺服器中所面對的問題。比如我們在socketServer模組的基礎上,編寫了下面這個簡單的echo服務程式:

from socketserver import StreamRequestHandler, TCPServer
class EchoHandler(StreamRequestHandler):
    def handle(self):
        for line in self.rfile:
            self.wfile.write(b'GoT:' + line)

serv = TCPServer(('', 15000), EchoHandler)
serv.serve_forever()

現在,我們想在EchoHandler類中增加一個__init__()方法,它接受額外的一個配置引數,用於事先指定ack。即:

class EchoHandler(StreamRequestHandler):
    def __init__(self, *args, ack, **kwargs):
        self.ack = ack
        super().__init__(*args, **kwargs) 
    def handle(self) -> None:
        for line in self.rfile:
            self.wfile.write(self.ack + line)

假如我們就這樣直接改動,就會發現後面會提示__init__()函式缺少keyword-only引數ack(這裡呼叫EchoHandler()初始化物件的時候會隱式呼叫__init__()函式)。 我們用partical()也能輕鬆解決這個問題,即為EchoHandler()事先提供好ack引數。

from functools import partial
serv = TCPServer(('', 15000), partial(EchoHandler, ack=b'RECEIVED'))
serv.serve_forever()

3、在回撥函式中攜帶額外的狀態

我們知道,我們呼叫回撥函式後,就會跳轉到一個全新的環境,此時會丟失我們原本的環境狀態。接下來我們討論如何在回撥函式中攜帶額外的狀態以便在回撥函式內部使用。
因為對回撥函式的應用在與非同步處理相關的庫和框架中比較常見,我們下面的例子也多和非同步處理相關。現在我們定義了一個非同步處理函式,它會呼叫一個回撥函式。

def apply_async(func, args, *, callback):
    # 計算結果
    result = func(*args)
    # 將結果傳給回撥函式
    callback(result)

下面展示上述程式碼如何使用:

# 要回撥的函式
def print_result(result):
    print("Got: ", result)
    
def add(x, y):
    return x + y

apply_async(add, (2, 3), callback=print_result)
# Got: 5
apply_async(add, ('hello', 'world'), callback=print_result)
# Got: helloworld

現在我們希望回撥函式print_reuslt()能夠接受更多的引數,比如其他變數或者環境狀態資訊。比如我們想讓print_result()函式每次的列印資訊都包括一個序列號,以表示這是第幾次被呼叫,如[1] ...[2] ...這樣。首先我們想到,可以用額外的引數在回撥函式中攜帶狀態,然後用partial()來處理引數個數問題:

class SequenceNo:
    def __init__(self) -> None:
        self.sequence = 0

def handler(result, seq):
    seq.sequence += 1
    print("[{}] Got: {}".format(seq.sequence, result))

seq = SequenceNo()
from functools import partial
apply_async(add, (2, 3), callback=partial(handler, seq=seq)) 
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=partial(handler, seq=seq))
# [2] Got: helloworld

看起來整個程式碼有點鬆散繁瑣,我們有沒有什麼更簡潔緊湊的方法能夠處理這個問題呢?答案是直接使用和其他類繫結的方法(bound-method)。比如面這段程式碼就將print_result做為一個類的方法,這個類儲存了計數用的ack序列號,每當呼叫print_reuslt()列印一個結果時就遞增1:

class ResultHandler:
    def __init__(self) -> None:
        self.sequence = 0
    def handler(self, result):
        self.sequence += 1
        print("[{}] Got: {}".format(self.sequence, result))

apply_async(add, (2, 3), callback=r.handler) 
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=r.handler) 
# [2] Got: helloworld

還有一種實現方法是使用閉包,這種方法和使用類繫結方法相似。但閉包更簡潔優雅,執行速度也更快:

def make_handler():
    sequence = 0
    def handler(result):
        nonlocal sequence # 在閉包中編寫函式來修改內層變數,需要用nonlocal宣告
        sequence += 1
        print("[{}] Got: {}".format(sequence, result))
    return handler

handler = make_handler()
apply_async(add, (2, 3), callback=handler) 
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=handler) 
# [2] Got: helloworld

最後一種方法,則是利用協程(coroutine)來完成同樣的任務:

def make_handler_cor():
    sequence = 0
    while True:
        result = yield
        sequence += 1
        print("[{}] Got: {}".format(sequence, result))

handler = make_handler_cor()
next(handler) # 切記在yield之前一定要加這一句
apply_async(add, (2, 3), callback=handler.send) #對於協程來說,可以使用它的send()方法來做為回撥函式
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=handler.send)
# [2] Got: helloworld

參考文獻

  • [1] https://www.python.org/
  • [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2005.

相關文章