『Pythonのレベルを上げる12のデコレータ』メモ

はじめに

このような記事を見つけました。
functools等で提供されているデコレータではなく、使用頻度の高いデコレータを独自に実装しようという試みの記事だそうです。
Pythonのレベルを上げるために良さそうなモノを上げ、自分なりにコードを書き換えてメモりました。
towardsdatascience.com

logger

関数のログを出力できるデコレータです。
自身で定義した関数の上につけることで関数にどういった引数が入力されて実行されたのかなどを出力することができます。

def logger(function):
    def wrapper(*args, **kwargs):
        # args: ('aaaa',)
        arg_str: str = ",".join(args)
        # kwargs: {'endline': '\n'}
        kwargs_str: str = ",".join(f"{k}={v}" for k, v in kwargs.items())
        executed_function_name: str = f"{function.__name__}({arg_str}, {kwargs_str})"
        print(f"----- {executed_function_name}: start -----")
        output = function(*args, **kwargs)
        print(f"----- {executed_function_name}: end -----")
        return output

    return wrapper

@logger
def print_text(text: str, endline) -> None:
    print(text, end=endline)
    print()

print_text("aaaa", endline="1")

出力結果

----- print_text(aaaa, endline=1): start -----
aaaa1
----- print_text(aaaa, endline=1): end -----

関数の引数*argsargsはtuple型, **kwargskwargsはdict型です。
下記の処理にてデコレータをつけた関数に入力された引数を追跡することができます。

arg_str: str = ",".join(args)
kwargs_str: str = ",".join(f"{k}={v}" for k, v in kwargs.items())
executed_function_name: str = f"{function.__name__}({arg_str}, {kwargs_str})"

wraps

下記は関数add_two_numbers()にデコレータloggerをつけている様子です。
関数add_two_numbers(a,b)は、デコレータの性質上、実際の機能はlogger(add_two_numbers())です。
実行している部分はwrapper(*args, **kwargs)です。

def logger(function):
    def wrapper(*args, **kwargs):
        """wrapper documentation"""
        print(f"----- {function.__name__}: start -----")
        output = function(*args, **kwargs)
        print(f"----- {function.__name__}: end -----")
        return output
    return wrapper

@logger
def add_two_numbers(a, b):
    """this function adds two numbers"""
    return a + b

したがって、add_two_numbers()関数の属性を取得するとデコレータで定義されたものが取得されてしまいます。
これはadd_two_numbers()の情報を取得しているように見えますが、デコレータ内部の情報を取得していることになり直感に反します。

add_two_numbers.__name__
'wrapper'
add_two_numbers.__doc__
'wrapper documentation'

@wraps(function) を宣言すると無事もとの関数function通りの属性を取得することができます。

from functools import wraps

def logger(function):
    @wraps(function)
    def wrapper(*args, **kwargs):
        """wrapper documentation"""
        print(f"----- {function.__name__}: start -----")
        output = function(*args, **kwargs)
        print(f"----- {function.__name__}: end -----")
        return output
    return wrapper

@logger
def add_two_numbers(a, b):
    """this function adds two numbers"""
    return a + b
add_two_numbers.__name__
'add_two_numbers'
add_two_numbers.__doc__
'this function adds two numbers'

lru_cache

関数のメモ化をサポートするデコレータです。
下記のようなフィボナッチ数を求める関数で、既に計算した値をキャッシュしておくことができます。
lru_cacheをつけない場合、実行に4.8秒かかります。

import time

def fib(n):
    return n if n < 2 else fib(n-1) + fib(n-2)

begin = time.time()

for n in range(32):
    print(n, fib(n))

end = time.time()
print(end - begin)

一方でfib()関数に@lru_cacheをつけると8.4e-05秒で実行することができます。

import time
from functools import lru_cache

@lru_cache
def fib(n):
    return n if n < 2 else fib(n-1) + fib(n-2)


begin = time.time()

for n in range(36):
    print(n, fib(n))

end = time.time()
print(end - begin)

実行結果

0 0
1 1
2 1
3 2
(略)
34 5702887
35 9227465
8.487701416015625e-05

fib.cache_info()でヒット件数やミスした件数などを参照できるそうです。
また、@lru_cache(maxsize=N)とすることで最大のキャッシュ保存件数をN件に指定することができるそうです。
maxsize=Noneにすると通常の@cacheと同じになり、制限なしにキャッシュができるそうです。
docs.python.org

これを独自実装すると次のようになります。

from functools import wraps
import random
import time

def cache(function):
    @wraps(function)
    def wrapper(*args, **kwargs):
        cache_key = args + tuple(kwargs.items())  # メモ化の辞書に記録するために(5, ('a', 3))のようなタプルを作る
        if cache_key in wrapper.cache:  # メモがあれば
            output = wrapper.cache[cache_key]
        else:
            output = function(*args, **kwargs)  # 関数実行
            wrapper.cache[cache_key] = output  # メモ化
        return output
    wrapper.cache = dict()  # メモを初期化
    return wrapper

@cache
def heavy_processing(n, a):
    sleep_time = n + random.random()
    time.sleep(sleep_time)


heavy_processing(5, a=3)

repeat

関数をN回繰り返すようにするデコレータです。

def repeat(number_of_times):  # デコレータの引数
    def decorate(func):  # デコレータの引数の関数
        @wraps(func)  # 関数の情報がwrapperにならないようにする
        def wrapper(*args, **kwargs):  # 関数の引数
            for _ in range(number_of_times):   # 関数を繰り返す
                func(*args, **kwargs)
        return wrapper
    return decorate
@repeat(5)
def dummy():
    print("hello")

dummy()
# hello
# hello
# hello
# hello
# hello

timeit

デコレータをつけた関数の実行時間を表示します。
デコレータ内のwrapper()にて元の関数を実行し、その関数の前後にstart, endで現在時刻を取得する変数を置きend-startの時刻を表示すればよいです。

import time
from functools import wraps

# このデコレータをつけた関数の実行時間を表示する
def timeit(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f'{func.__name__} took {end - start:.6f} seconds to complete')
        return result
    return wrapper

@timeit
def process_data():
    time.sleep(1)

process_data()
# process_data took 1.001530 seconds to complete

retry

このデコレータをつけた関数はエラーが起きた場合、繰り返しが発生します。

import random
import time
from functools import wraps

def retry(num_retries, exception_to_check, sleep_time=0):
    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(1, num_retries+1):
                try:
                    return func(*args, **kwargs)
                except exception_to_check as e:
                    print(f"{func.__name__} raised {e.__class__.__name__}. Retrying...")
                    if i < num_retries:
                        time.sleep(sleep_time)
            raise e
        return wrapper
    return decorate

# 最大3回まで繰り返される
# exception_to_check だった場合、繰り返される
# 間隔は1秒ごとである
@retry(num_retries=3, exception_to_check=ValueError, sleep_time=1)
def random_value():
    value = random.randint(1, 5)
    if value == 3:
        raise ValueError("Value cannot be 3")
    return value

countcall

デコレータをつけた関数が何回コールされたかを記録するデコレータです。

from functools import wraps

def countcall(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.count += 1  # インクリメント
        result = func(*args, **kwargs)
        print(f'{func.__name__} has been called {wrapper.count} times')
        return result
    wrapper.count = 0  # wrapper.count を0に初期化
    return wrapper

@countcall
def process_data():
    pass

process_data()
# process_data has been called 1 times
process_data()
# process_data has been called 2 times
process_data()
# process_data has been called 3 times

デコレータをつけた関数の前でwrapperに初期値0のプロパティをつけ、関数の実行ごとにwrapper.countをインクリメントすることで実行された回数を記録しています。
デコレータ定義の中のwrapper()以外は宣言時初回しか実行されないことに注意です。

def countcall(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.count += 1
        result = func(*args, **kwargs)
        print(f'{func.__name__} has been called {wrapper.count} times')
        return result
    print("-" * 30)
    wrapper.count = 0
    return wrapper

@countcall
def process_data():
    pass

print("#1")
process_data()
print("#2")
process_data()
print("#3")
process_data()

この実行結果は下記のとおり、------が1回のみ表示されます。
さらに初回関数実行前(process_data()宣言時)に-----があることに注意です。

------------------------------
#1
process_data has been called 1 times
#2
process_data has been called 2 times
#3
process_data has been called 3 times

また、Pythonの関数型(class 'function')には任意のプロパティを付けることが出来ます。

def func():
    pass

print("count" in dir(func))  # False
func.count = 0
print("count" in dir(func))  # True

rate_limited

既にrate_limitedアルゴリズムを実現するデコレータがあります。
こちらを使うと下記のように実装ができます。
call_api()関数は引数urlにGETリクエストを送りレスポンスを受け取る関数です。
call_api()関数に@limit(calls=3, period=10)のデコレータを付与するとcall_api()関数は10秒間に3回のみ実行することが許されます。
それを超過してしまうとratelimit.RateLimitExceptionが発生します。

from datetime import datetime
import requests
from ratelimit import limits, sleep_and_retry

@limits(calls=3, period=10)
def call_api(url):
    response = requests.get(url)
    if response.status_code != 200:
        raise Exception("API response: {}".format(response.status_code))
    return response

url = "https://cat-fact.herokuapp.com/facts/"
for i in range(10):
    response = call_api(url)
    print(i, datetime.now().isoformat(), response.status_code)

実行結果は下記の通りです。

0 2023-04-10T01:13:47.626372 200
1 2023-04-10T01:13:48.500790 200
2 2023-04-10T01:13:49.392809 200
Traceback (most recent call last):
  File "/home/gesogeso/src/tmp/endpoint.py", line 18, in <module>
    response = call_api(url)
  File "/home/gesogeso/src/tmp/.tmp/lib/python3.10/site-packages/ratelimit/decorators.py", line 77, in wrapper
    raise RateLimitException('too many calls', period_remaining)
ratelimit.exception.RateLimitException: too many calls

@sleep_and_retryを更に付与するとratelimit.RateLimitExceptionが発生した際に次に実行できる時刻まで待機し実行できるようになったら再び関数を実行してくれます。

@sleep_and_retry
@limits(calls=3, period=10)
def call_api(url):

実行結果は下記のようになります。

0 2023-04-10T01:12:22.349554 200
1 2023-04-10T01:12:23.208253 200
2 2023-04-10T01:12:24.095663 200
3 2023-04-10T01:12:32.403814 200  # 1から10秒後
4 2023-04-10T01:12:33.279416 200
5 2023-04-10T01:12:34.135351 200
6 2023-04-10T01:12:42.400339 200  # 3から10秒後
7 2023-04-10T01:12:43.249893 200
8 2023-04-10T01:12:44.100634 200
9 2023-04-10T01:12:52.436109 200  # 6から10秒後

github.com

このデコレータを実装するとなると下記のようになります。
参照先のコードで動かなかったため、最後にcall_api()を実行した時刻を記録する変数をグローバル変数にし関数内でグローバル変数を書き換えられるように修正しました。

import time
from datetime import datetime
from functools import wraps
import requests

last_time_called = 0.0  # 最後に関数を実行した時刻を記録

def rate_limited(max_per_second):
    min_interval = 1.0 / float(max_per_second)  # 1回のcall間のインターバル

    def decorate(func):
        @wraps(func)
        def rate_limited_function(*args, **kargs):
            global last_time_called  # グローバル変数を書き換えるために読み込む
            elapsed = time.perf_counter() - last_time_called  # 現在の時刻 - 最後にcallを実行した時刻
            left_to_wait = min_interval - elapsed  # インターバル > 経過時間であれば(インターバル - 経過時間)待機
            if left_to_wait > 0:
                time.sleep(left_to_wait)
            ret = func(*args, **kargs)
            last_time_called = time.perf_counter()  # 最後に実行した時刻を記録
            return ret

        return rate_limited_function

    return decorate

@rate_limited(max_per_second=0.5)  # 1秒間に0.5回まで実行できる = 2秒間に1回実行できる
def call_api(url):
    response = requests.get(url)
    if response.status_code != 200:
        raise Exception("API response: {}".format(response.status_code))
    return response

url = "https://cat-fact.herokuapp.com/facts/"
# response = call_api(url)
# print(response.json())
for i in range(10):
    response = call_api(url)
    print(i, datetime.now().isoformat(), response.status_code)

この結果は下記のようになりました。

0 2023-04-10T01:26:53.290828 200
1 2023-04-10T01:26:56.139231 200
2 2023-04-10T01:26:59.011727 200
3 2023-04-10T01:27:01.872562 200
4 2023-04-10T01:27:04.811276 200
5 2023-04-10T01:27:07.660686 200
6 2023-04-10T01:27:10.554552 200
7 2023-04-10T01:27:13.416593 200
8 2023-04-10T01:27:16.312978 200
9 2023-04-10T01:27:19.136398 200

@rate_limited(max_per_second=0.5)をつけずに実行すると下記のようになります。
つけない場合は1秒ごとに実行されていることを考えると、関数の処理に1秒かかることがわかります。
上記のつけた場合の結果の出力では3秒ごとに実行されているため、実行の1秒分を引くと待機時間が2秒となっていることが分かります。

0 2023-04-10T01:26:26.350047 200
1 2023-04-10T01:26:27.241829 200
2 2023-04-10T01:26:28.103290 200
3 2023-04-10T01:26:29.008042 200
4 2023-04-10T01:26:29.856392 200
5 2023-04-10T01:26:30.710144 200
6 2023-04-10T01:26:31.601558 200
7 2023-04-10T01:26:32.456266 200
8 2023-04-10T01:26:33.340306 200
9 2023-04-10T01:26:34.221743 200

register

registerはプログラムが中断(Ctrl+C)された際に実行されるデコレータです。
下記の例だと次々にHelloが出力されますが、Ctrl+Cを押すとGoodbye!が出力され終了します。

from atexit import register

@register
def terminate():
    print("Goodbye!")

while True:
    print("Hello")

singledispatch

singledispathをつけた関数の引数を型ごとに分けることができます。
C#のコンストラクタのように引数の型によって何を実行するのかを選んでくれるようになります。

from functools import singledispatch

@singledispatch
def fun(arg):
    print("Called with a single argument")

# argの型がintだったらコレが実行される
@fun.register(int)
def _(arg):
    print("Called with an integer")

# argの型がlistだったらコレが実行される
@fun.register(list)
def _(arg):
    print("Called with a list")

fun(1)  # Prints "Called with an integer"
fun([1, 2, 3])  # Prints "Called with a list"

docs.python.org