利用 redis 實作 rate limit

最近剛好要處理 API 限流的問題,所以找了找相關的文章

後來找到了這篇 Rate Limiting Algorithms using Redis 提供了不少解法,這邊抄錄其中幾個方法,並轉換成 Python 給自己記憶一下。

固定時間區間 Fixed window

最基本的解法就是直接將所有的 request 以一個固定的時間區間去分隔並記錄,例如每小時的所有請求數量,然後時間到就直接重置。

import time
import redis
class FixedWindowCounter:
def __init__(self, db: redis.Redis, ttl_seconds: int) -> None:
assert isinstance(db, redis.Redis)
assert isinstance(ttl_seconds, int)
self.db = db
self.ttl_seconds = ttl_seconds
def get_current_bucket(self) -> str:
time_bucket = int(time.time()) // self.ttl_seconds
return f"ratelimit:{time_bucket}" # insert user id here
def increment(self, value: int) -> None:
key = self.get_current_bucket()
with self.db.pipeline() as pipe:
pipe.incr(key, value)
pipe.expire(key, self.ttl_seconds)
pipe.execute()
def get_count(self) -> int:
key = self.get_current_bucket()
cnt = self.db.get(key)
if cnt is None: # redis return None on key not found
return 0
return int(cnt)

優點

實作容易,使用的記憶體少

缺點

容易被閃掉,在接近時間區間尾端時開打的大量請求相對容易被分在兩個不同的計數器內

逐筆紀錄

在原文中有 Sliding Logs 及 Leaky Bucket 是逐次紀錄過去時間段內所有的請求數,分別是用 Sorted SetList 去紀錄時間內的總請求數量。這兩個方案記憶體量會稍高一點點,先跳過。

加權總和 Weighted sum

在上面的文章中,原作者稱這個方法是 Sliding window,但我覺得概念不太一樣,所以改個名稱記著。

這個方法是基於 Fixed window,但是以當下區間中的計數加上前一個區間的部分計數,按時間加權去估計是否已經超標。

class WeightedSumCounter:
def __init__(self, db: redis.Redis, ttl_seconds: int) -> None:
assert isinstance(db, redis.Redis)
assert isinstance(ttl_seconds, int)
self.db = db
self.ttl_seconds = ttl_seconds
def get_bucket_start_time(self, offset=0) -> int:
n_interval = int(time.time()) // self.ttl_seconds
timestamp = (n_interval + offset) * self.ttl_seconds
return timestamp
def get_bucket(self, offset=0) -> str:
return f"ratelimit:{self.get_bucket_start_time(offset)}" # insert user id here
def increment(self, value: int) -> None:
key = self.get_bucket()
with self.db.pipeline() as pipe:
pipe.incr(key, value)
pipe.expire(key, self.ttlSeconds)
pipe.execute()
def get_count(self) -> float:
cnt = self.db.get(self.get_bucket()) or 0
if cnt:
cnt = int(cnt)
cnt_last = self.db.get(self.get_bucket(-1))
if cnt_last:
elapsed_time = int(time.time()) - self.get_bucket_start_time()
ratio_last_bucket = 1 - (elapsed_time / self.ttl_seconds)
cnt += cnt_last * ratio_last_bucket
return cnt

優點

記憶體使用效率高,且可以解決大流量剛好發生在時間區間段切換的問題。

缺點

  1. 在特定的請求數量、時間分配下,依然可以發生超過限額的狀況
  2. 會發生 Traffic shaping,一定時間內能處理的請求數會被卡到,所以可能會造成使用者抱怨,但另一方面卻能預防伺服器過載——原文這麼提及了,但我的程度還沒去理解出為啥,但總之先記著

移動時間區間 Sliding window

既然固定時間區間有不精準的問題,逐次紀錄又消耗太多資源,那麼折衷的方式就是把時間區間分成多個小區間,然後把時間範圍內各小區間的計數加一加。

這個實作是部門前輩建議的,並利用到了 Redis 的 Hashes。這裡的鍵是每個小區間的起始時間,而值是該時間區間有多少請求數,user id 之類的額外區分就直接丟在 name 裡面了。

整理這個段落時思考到似乎用 Sorted set 可以有更好的效率,哪天想起來再來整理看看

class SlidingWindowCounter:
def __init__(self, db: redis.Redis, ttl_seconds: int, subinterval_seconds: int) -> None:
assert isinstance(db, redis.Redis)
assert isinstance(ttl_seconds, int)
assert isinstance(subinterval_seconds, int)
assert ttl_seconds > subinterval_seconds
self.db = db
self.ttl_seconds = ttl_seconds
self.subinterval_seconds = subinterval_seconds
def get_name(self) -> str:
return f"ratelimit:general" # insert user id here
def get_bucket_start_time(self, timestamp=None) -> int:
timestamp = timestamp or time.time()
timestamp = int(timestamp) // self.subinterval_seconds * self.subinterval_seconds
return timestamp
def increment(self, value: int) -> None:
name = self.get_name()
with self.db.pipeline() as pipe:
pipe.hincrby(name, self.get_bucket_start_time(), value)
pipe.expire(name, self.ttl_seconds)
pipe.execute()
def get_count(self) -> int:
name = self.get_name()
# drop expired buckets
base_time = self.get_bucket_timestamp()
expired_buckets = set()
for bucket in self.db.hkeys(name):
if base_time - int(bucket) > self.ttl_seconds:
expired_buckets.add(bucket)
if expired_buckets:
self.db.hdel(name, *expired_buckets)
# get value and sum
counts = self.db.hvals(name)
return sum(map(int, counts))

優點

記憶體使用效率高,且可以數到精確的請求總數。

缺點

操作比較複雜,可預想其反應速度比較慢。

一個可考慮的小改善是把 get_count() 中的過期檢查直接拔掉,則數字會稍微不精準一點點、但少掉兩個 transaction,其計數變成仰賴 Redis 本身的 TTL 達成。

Token bucket

原文中有出現的方法,我覺得很有趣,雖然不適用最近自己的需求。

前面有紀錄的方法中都是正向去數已經有多少請求,而我們也可以反過來數剩下多少額度。直觀上好像只差一個負號沒啥差,但搭配 Redis 內建的 TTL 卻有奇效——我們可以精確的知道什麼時候超過了限額、需要進行阻擋,而且操作又相對精簡。

class TokenBucketRateLimit:
def __init__(self, db: redis.Redis, ttl_seconds: int, max_count: int) -> None:
assert isinstance(db, redis.Redis)
assert isinstance(ttl_seconds, int)
assert isinstance(max_count, int)
self.db = db
self.ttl_seconds = ttl_seconds
self.max_count = max_count
def get_name(self) -> str:
return f"ratelimit:general" # insert user id here
def inc_and_check(self, value):
"""Increment the counter and check the threshold. Return True when limit
is exceeded.
"""
assert value < self.max_count
name = self.get_name()
remain = self.db.get(name)
if remain is None:
# never seen or expired
with self.db.pipeline() as pipe:
pipe.set(name, self.max_count - value)
pipe.expire(name, self.ttl_seconds)
pipe.execute()
elif remain < value:
# reach threshold
return True
else:
# counter not expired, not reach threshold
self.db.decr(name, value)
return False

優點

記憶體使用效率高,操作數量少

查文件時發現 Redis 7.0 有個 Expire option ,利用進來後似乎有機會再更減少一點操作?

缺點

  1. 跟加權總和一樣,特定條件下仍有機會超標,但這個方法因為起始的時間段是彈性的,所以超標的難度更高
  2. 彈性較低,這個方法無法在長時間、非大量的請求時用於得知其在各時段的精確計數