技法3-2:當函式做為引數時的技巧

orion發表於2022-02-02

我們之前在Python技法3: 匿名函式、回撥函式、高階函式(連結:https://www.cnblogs.com/orion-orion/p/15427594.html)
中提到,可以通過lambda表示式來為函式設定預設引數,從而修改函式的引數個數:

import math
def distance(p1, p2):
    x1, y1 = p1
    x2, y2 = p2
    return math.hypot(x2 - x1, y2 - y1)
points = [(1, 2), (3, 4), (5, 6), (7, 8)]
pt = (4, 3)
points.sort(key=lambda p: distance(p, pt))
print(points)
# [(3, 4), (1, 2), (5, 6), (7, 8)]

下面我們在深度學習專案情境中展示下該手法的應用。
我們這是一個聯邦學習專案,有多個client客戶端,每個client中都有機器學習模型。我們現在有一份現有的祖傳程式碼不能改動,該祖傳程式碼中client類的初始化函式中需要傳入模型類的初始化函式和優化器類的初始化函式(注意,不是模型的物件和優化器的物件),然後在client類的建構函式中完成模型物件的初始化和優化器物件的初始化。
client類的建構函式部分如下所示如下:

class Client(FederatedTrainingDevice):
    def __init__(self, model_fn, optimizer_fn, batch_size=128, train_frac=0.8):
        self.model = model_fn().to(device)
        self.optimizer = optimizer_fn(self.model.parameters())
        ...

但是,這樣的程式碼就會面臨一個問題,modeloptimizer的初始化是需要超引數的,超引數如何傳進去?這時,就可以用我們前面所說的lambda技巧來解決了:

client = Client(lambda: ConvNet(input_size=input_sz, num_classes=num_cls), lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9)

這一還有一個注意的地方,就是lambda函式是可以無引數的,故可以用於完成我們這裡model_fn函式的替代。

相關文章