用於資料科學的幾種Python裝飾器介紹 - Bytepawn

banq發表於2022-05-24

在這篇文章中,我將展示一些@decorators可能對資料科學家有用的東西:

@parallel
讓我們假設我寫了一個非常低效的方法來尋找素數:

from sympy import isprime

def generate_primes(domain: int=1000*1000, num_attempts: int=1000) -> list[int]:
    primes: set[int] = set()
    seed(time())
    for _ in range(num_attempts):
        candidate: int = randint(4, domain)
        if isprime(candidate):
            primes.add(candidate)
    return sorted(primes)

print(len(generate_primes()))


輸出:88

然後我意識到,如果我在所有的CPU執行緒上並行執行原來的generate_primes(),我可以得到一個 "免費 "的加速。這是很常見的,定義一個@parallel用法:

def parallel(func=None, args=(), merge_func=lambda x:x, parallelism = cpu_count()):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            results = Parallel(n_jobs=parallelism)(delayed(func)(*args, **kwargs) for i in range(parallelism))
            return merge_func(results)
        return inner
    if func is None:
        # decorator was used like @parallel(...)
        return decorator
    else:
        # decorator was used like @parallel, without parens
        return decorator(func)


有了這個,只需一行,我們就可以將我們的函式並行化。

@parallel(merge_func=lambda li: sorted(set(chain(*li))))
def generate_primes(...): # same signature, nothing changes
    ... # same code, nothing changes

print(len(generate_primes()))


輸出:1281

在我的例子中,我的Macbook有8個核心,16個執行緒(cpu_count()是16),所以我產生了16倍的素數。

注意:
唯一的開銷是必須定義一個merge_func,它將函式的不同執行結果合併為一個結果,以便向裝飾函式(本例中為 generate_primes())的外部呼叫者隱藏並行性。在這個玩具例子中,我只是合併了列表,並透過使用 set() 確保素數是唯一的。
有許多Python庫和方法(例如執行緒與程式)可以實現並行。
這個例子使用了joblib.Parallel()的程式並行,它在Darwin + python3 + ipython上執行良好,並且避免了對Python全域性直譯器鎖(GIL)的鎖定。

@production
有時候,我們寫了一個複雜的管道,有一些額外的步驟,我們只想在某些環境下執行。例如,在我們的本地開發環境中做一些事情,但在生產環境中不做,反之亦然。如果能夠對函式進行裝飾,讓它們只在某些環境下執行,而在其他地方不做任何事情,那就更好了。

實現這一目標的方法之一是使用一些簡單的裝飾器。@production表示我們只想在prod上執行的東西,@development表示我們只想在dev中執行的東西,我們甚至可以引入一個@inactive,將函式完全關閉。這種方法的好處是,這種方式可以在程式碼/Github中跟蹤部署歷史和當前狀態。另外,我們可以在一行中做出這些改變,從而使提交更簡潔;例如,@inactive比整個程式碼塊被註釋掉的大提交要乾淨。

production_servers = [...]

def production(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is not a production server, skipping function decorated with @production...')
    return inner

def development(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() not in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is a production server, skipping function decorated with @development...')
    return inner

def inactive(func: Callable):
    def inner(*args, **kwargs):
        print('Skipping function decorated with @inactive...')
    return inner

@production
def foo():
    print('Running in production, touching databases!')

foo()

@development
def foo():
    print('Running in production, touching databases!')

foo()

@inactive
def foo():
    print('Running in production, touching databases!')

foo()


輸出:

Running in production, touching databases!
This host is a production server, skipping function decorated with @development...
Skipping function decorated with @inactive...

這個想法可以適用於其他框架/環境。

@deployable
在我目前的工作中,我們使用Airflow進行ETL/資料管道。我們有一個豐富的輔助函式庫,可以在內部構建適當的DAG,所以使用者(資料科學家)不必擔心這個問題。

最常用的是dag_vertica_create_table_as(),它在我們的Vertica DWH上執行一個SELECT,每晚將結果轉儲到一個表中。

dag = dag_vertica_create_table_as(
    table='my_aggregate_table',
    owner='Marton Trencseni (marton.trencseni@maf.ae)',
    schedule_interval='@daily',
    ...
    select="""
    SELECT
        ...
    FROM
        ...
    """
)

然後這就變成了對DWH的查詢,大致是這樣:

CREATE TABLE my_aggregate_table AS
SELECT ...



實際上,情況更復雜:我們首先執行今天的查詢,如果今天的查詢被成功建立,則有條件地刪除昨天的查詢。這個條件邏輯(以及其他一些針對我們環境的意外的複雜性,比如必須釋出GRANTs)導致DAG有9個步驟,但這不是這裡的重點,也超出了本文的範圍。

在過去的兩年裡,我們已經建立了近500個DAG,所以我們擴大了Airflow EC2例項的規模,並引入了獨立的開發和生產環境。如果能有一種方法來標記DAG是應該在開發環境還是生產環境中執行,在程式碼/Github中跟蹤這一點,並使用相同的機制來確保DAG不會意外地執行在錯誤的環境中,那就更好了。

大約有10個類似的便利函式,如dag_vertica_create_or_replace_view_as()和dag_vertica_train_predict_model()等,我們希望這些dag_xxx()函式的所有呼叫都可以在生產和開發之間切換(或者到處跳過)。

然而,上一節中的@production和@development裝飾器在這裡不起作用,因為我們不想將dag_vertica_create_table_as()切換為永遠不在其中一個環境中執行。我們希望能夠在每次呼叫時進行設定,並且在我們所有的dag_xxxx()函式中都有這個功能,而不需要複製/貼上程式碼。我們想要的是在我們所有的dag_xxxx()函式中新增一個部署引數(有一個好的預設值),這樣我們就可以在我們的DAG中新增這個引數,以增加安全性。我們可以透過@deployable裝飾器來實現這個目標。

def deployable(func):
    def inner(*args, **kwargs):
        if 'deploy' in kwargs:
            if kwargs['deploy'].lower() in ['production', 'prod'] and gethostname() not in production_servers:
                print('This host is not a production server, skipping...')
                return
            if kwargs['deploy'].lower() in ['development', 'dev'] and gethostname() not in development_servers:
                print('This host is not a development server, skipping...')
                return
            if kwargs['deploy'].lower() in ['skip', 'none']:
                print('Skipping...')
                return
            del kwargs['deploy'] # to avoid func() throwing an unexpected keyword exception
        return func(*args, **kwargs)
    return inner


然後,我們可以將裝飾器新增到我們的函式定義中(每個函式新增1行)。

@deployable
def dag_vertica_create_table_as(...): # same signature, nothing changes
    ... # code signature, nothing changes

@deployable
def dag_vertica_create_or_replace_view_as(...): # same signature, nothing changes
    ... # code signature, nothing changes

@deployable
def dag_vertica_train_predict_model(...): # same signature, nothing changes
    ... # code signature, nothing changes


如果我們在這裡停止,什麼也不會發生,我們不會破壞任何東西。
然而,現在我們可以到我們使用這些函式的DAG檔案中,增加1行。

dag = dag_vertica_create_table_as(
    deploy='development', # the function will return None on production
    ...
)


@redirect (stdout)
有時我們寫一個大的函式,也會呼叫其他程式碼,各種資訊都會被列印()出來。或者,我們可能有一個bug,有一堆print(),想在列印出來的內容上加上行號,這樣就可以更容易地參考它們。在這些情況下,@redirect可能是有用的。這個裝飾器將print()的標準輸出重定向到我們自己的逐行印表機,我們可以對它做任何我們想做的事情(包括扔掉它)。

def redirect(func=None, line_print: Callable = None):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            with StringIO() as buf, redirect_stdout(buf):
                func(*args, **kwargs)
                output = buf.getvalue()
            lines = output.splitlines()
            if line_print is not None:
                for line in lines:
                    line_print(line)
            else:
                width = floor(log(len(lines), 10)) + 1
                for i, line in enumerate(lines):
                    i += 1
                    print(f'{i:0{width}}: {line}')
        return inner
    if func is None:
        # decorator was used like @redirect(...)
        return decorator
    else:
        # decorator was used like @redirect, without parens
        return decorator(func)



如果我們使用redirect()而不指定明確的line_print()函式,它就會列印行數,但要加上行號。

@redirect
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(10)

Output:

01: Line #1
02: Line #2
03: Line #3
04: Line #4
05: Line #5
06: Line #6
07: Line #7
08: Line #8
09: Line #9
10: Line #10



如果我們想把所有的列印文字儲存到一個變數中,我們也可以實現這一點。

lines = []
def save_lines(line):
    lines.append(line)

@redirect(line_print=save_lines)
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(3)
print(lines)


Output:

<p class="indent">['Line #1', 'Line #2', 'Line #3']


重定向stdout的實際工作是由contextlib.redirect_stdout完成的。


@stacktrace
下一個裝飾器模式是@stacktrace,當函式被呼叫和從函式返回值時,它會發出有用的資訊。

def stacktrace(func=None, exclude_files=['anaconda']):
    def tracer_func(frame, event, arg):
        co = frame.f_code
        func_name = co.co_name
        caller_filename = frame.f_back.f_code.co_filename
        if func_name == 'write':
            return # ignore write() calls from print statements
        for file in exclude_files:
            if file in caller_filename:
                return # ignore in ipython notebooks
        args = str(tuple([frame.f_locals[arg] for arg in frame.f_code.co_varnames]))
        if args.endswith(',)'):
            args = args[:-2] + ')'
        if event == 'call':
            print(f'--> Executing: {func_name}{args}')
            return tracer_func
        elif event == 'return':
            print(f'--> Returning: {func_name}{args} -> {repr(arg)}')
        return
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            settrace(tracer_func)
            func(*args, **kwargs)
            settrace(None)
        return inner
    if func is None:
        # decorator was used like @stacktrace(...)
        return decorator
    else:
        # decorator was used like @stacktrace, without parens
        return decorator(func)



有了這個,我們就可以裝飾我們希望追蹤開始的最上面的函式,我們將得到關於分支的有用的輸出。

def b():
    print('...')

@stacktrace
def a(arg):
    print(arg)
    b()
    return 'world'

a('foo')
Output:

--> Executing: a('foo')
foo
--> Executing: b()
...
--> Returning: b() -> None
--> Returning: a('foo') -> 'world'


這裡唯一的訣竅是。在我的例子中,我在Anaconda上的ipython中執行這段程式碼,所以我隱藏了程式碼在路徑中有Anaconda的檔案中的部分呼叫棧(否則我在上面的片段中會得到大約50-100個無用的呼叫棧條目)。這是透過裝飾器的exclude_files引數完成的。


@traceclass
與上述類似,我們可以定義一個裝飾器@traceclass,與類一起使用,以獲得其成員的執行軌跡。這包括在之前的裝飾器帖子中,在那裡它只是被稱為@trace,並且有一個bug(在原來的帖子中已經修復)。這個裝飾器。

def traceclass(cls: type):
    def make_traced(cls: type, method_name: str, method: Callable):
        def traced_method(*args, **kwargs):
            print(f'--> Executing: {cls.__name__}::{method_name}()')
            return method(*args, **kwargs)
        return traced_method
    for name in cls.__dict__.keys():
        if callable(getattr(cls, name)) and name != '__class__':
            setattr(cls, name, make_traced(cls, name, getattr(cls, name)))
    return cls


使用:

@traceclass
class Foo:
    i: int = 0
    def __init__(self, i: int = 0):
        self.i = i
    def increment(self):
        self.i += 1
    def __str__(self):
        return f'This is a {self.__class__.__name__} object with i = {self.i}'

f1 = Foo()
f2 = Foo(4)
f1.increment()
print(f1)
print(f2)
Output:

--> Executing: Foo::__init__()
--> Executing: Foo::__init__()
--> Executing: Foo::increment()
--> Executing: Foo::__str__()
This is a Foo object with i = 1
--> Executing: Foo::__str__()
This is a Foo object with i = 4

 

相關文章