基於令牌桶演算法實現一個限流器

zzzggb發表於2024-10-26

序言:本文章基於令牌桶演算法實現了簡單的一個限流器

1 令牌桶演算法

實現原理

  • 令牌生成:在固定的時間間隔內,演算法會向一個桶中放入一定數量的令牌。令牌的生成速率是固定的,通常以每秒鐘生成的令牌數來表示。
  • 桶的容量:桶有一個最大容量,如果桶滿了,新的令牌將被丟棄。這意味著即使在高流量情況下,系統也不會無限制地增加請求。
  • 請求處理:每當一個請求到達時,它需要從桶中獲取一個令牌。如果桶中有令牌,請求就可以被處理,桶中的令牌數量減一。如果沒有令牌,請求將被拒絕或被延遲,具體取決於實現。
  • 流量控制:透過調整令牌的生成速率和桶的容量,可以控制流量的平穩性和最大流量。

演算法優點:

  • 可以承載一定的突發流量情況。
  • 限流視窗的變化相對平穩。

coding

根據令牌桶演算法原理,可以先定義出三個變數。桶容量令牌產生速率當前桶中的令牌數量。同時定義一個rateLimiter類和對應的構造方法:

public class RateLimiter {
    // 自己寫的日誌列印工具(線上程池的文章中有貼)
    static Logger log = new Logger(Logger.LogLevel.DEBUG, RateLimiter.class);
    // 桶容量
    private final int maxPermit;
    // 令牌產生速率 / 秒
    private final int rate;
    // 當前桶中的令牌數量(考慮到這個變數會多執行緒使用, 使用原子類來實現)
    private final AtomicInteger holdPermits;
    
    public RateLimiter(int maxPermit, int rate, int initPermits) {
        if (rate < 1) throw new IllegalArgumentException("the rate must be greater than 1");
        if (initPermits < 0) throw new IllegalArgumentException("the initPermits must be greater than 0");
        if (maxPermit < 1) throw new IllegalArgumentException("the maxPermit must be greater than 1");
        this.maxPermit = maxPermit;
        this.rate = rate;
        this.holdPermits = new AtomicInteger(initPermits);
    }
}

然後我們需要給這個類新增一個 boolean tryAcquire(int permit) 方法。表示獲取permit數量個令牌。如果獲取成功返回true,獲取失敗則返回false。
可以寫出這個方法:

/**
 * 嘗試獲取 permit 數量的令牌
 *
 * @param permit 令牌數量
 * @return 獲取到 permit 數量的令牌則返回 true, 否則返回 false
 */
public boolean tryAcquire(int permit) {
    if (permit > maxPermit) throw new IllegalArgumentException("the permit must be smaller than maxPermit:" + maxPermit);
    if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
    int currentPermits;
    do {
        currentPermits = holdPermits.get();
        if (currentPermits < permit) {
            return false;
        }
    } while (!holdPermits.compareAndSet(currentPermits, currentPermits - permit));
    // 日誌列印
    log.debug("原令牌數:" + currentPermits + ", 減少:" + permit + ", 當前總數:" + (currentPermits - permit));
    return true;
}

這個方法中藉助了原子類的compareAndSet操作和自旋來實現令牌的扣減。噹噹前桶中的令牌數量大於等於請求獲取的令牌數時,使用compareAndSet來實現令牌的扣減。
但是這個方法可能由於其他執行緒的併發執行(提前扣減了令牌)而失敗。所以需要自旋操作保證令牌數量足夠時可以正確獲得令牌許可。只有當桶中的令牌數小於請求要求的令牌數量時才會返回失敗。

至此,令牌桶的構造方法獲取令牌的方法已經實現完成。但是何時且如何向桶中放入令牌呢?
如果使用定時任務的話,那麼就需要建立額外的執行緒物件來完成。

這裡借鑑前人的智慧,在每一次獲取令牌時順便計算和更新令牌數量,這樣的話,我們還需要一個變數記住上一次計算令牌的時間。
所以在類中加一個變數 lastFreshTime 記錄上一次計算更新令牌的時間,同時由於這個變數可能被多個執行緒更改,使用原子類物件保證執行緒安全。

// 上次計算更新令牌的時間
private final AtomicLong lastFreshTime;

然後,我們還需要一個方法,用於每次獲取令牌前計算更新桶中的令牌數量,這個方法中首先根據過去時間的納秒數量計算出應該產生的令牌數量,使用int向下取整,
然後使用令牌數量反向計算產生這些令牌所需的準確時間(因為令牌數量使用int取整了),加上當前lastFreshTime的值即可以得到新的lastFreshTime的值。
這裡為了保證執行緒安全,lastFreshTime 的更新使用compareAndSet保證只有一個執行緒可以獲取更新許可權。這個執行緒在成功更新lastFreshTime後,需要繼續更新令牌的數量,
由於在tryAcquire中還可能出現其他執行緒扣減令牌數量的行為,所以這裡還需要保證更新操作的原子性。

/**
 * 重新整理令牌數量
 */
private void freshPermit() {
    long now = System.nanoTime();
    long lastTime = lastFreshTime.get();
    if (now <= lastTime) return;
    int increment = (int) ((now - lastTime) * rate / 1_000_000_000);
    long thisTime = lastTime + increment * 1_000_000_000L / rate;
    if (increment > 0 && lastFreshTime.compareAndSet(lastTime, thisTime)) {
        int current;
        int next;
        do {
            current = holdPermits.get();
            next = Math.min(maxPermit, current + increment);
        } while (!holdPermits.compareAndSet(current, next));
        log.debug("原令牌數:" + current + ", 增加:" + increment + ", 當前總數:" + holdPermits.get());
    }
}

至此,我們就實現了一個簡單的令牌桶實現程式碼,只要每次tryAcquire時先使用freshPermit更新一下令牌數量就可以了。

但通常來說,令牌桶還會有一個帶超時時間的boolean tryAcquire(int permit, long timeOut)方法。這裡我們做一個簡單的實現,使用定期的sleep操作而不是鎖機制來完成。

假設要獲取的令牌數量為 p,超時時間為 t,那麼在tryAcquire(p, t)方法中,如果當前令牌不足p的話,那麼執行緒將會sleep一定時間會再次嘗試獲取令牌,直到使用時間超過t仍未獲取成功才會返回失敗。

這裡有一個問題就是睡眠時間sleepDuration的確定,在這個程式碼中,sleepDuration的值為 (p / rate) * 1000 ms,且最小為10,最大為100。
實現程式碼:

/**
 * 在 timeout 時間內嘗試獲取 permit 數量的令牌
 *
 * @param permit  令牌數量
 * @param timeOut 超時時間 單位 秒
 * @return 如果在 timout 時間內獲取到 permit 數量的令牌則返回 true, 否則返回 false
 */
public boolean tryAcquire(int permit, long timeOut) {
    if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
    if (timeOut < 0) throw new IllegalArgumentException("the timeOut must be greater than 0");
    timeOut = timeOut * 1_000_000_000 + System.nanoTime();
    long sleepDuration = (long) (1000.0 * permit / rate);
    sleepDuration = Math.min(sleepDuration, 100);
    sleepDuration = Math.max(sleepDuration, 10);
    while (System.nanoTime() < timeOut) {
        if (tryAcquire(permit)) return true;
        else {
            try {
                Thread.sleep(sleepDuration);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }
    return false;
}

至此,我們就實現了一個簡單的使用令牌桶演算法的限流器類。
完整程式碼:

public class RateLimiter {

    static Logger log = new Logger(Logger.LogLevel.DEBUG, RateLimiter.class);
    /**
     * 最大令牌數量
     */
    private final int maxPermit;
    /**
     * 令牌產生速率 / 每秒
     */
    private final int rate;

    /**
     * 當前可用令牌數量
     */
    private final AtomicInteger holdPermits;
    /**
     * 上次計算令牌的時間
     */
    private final AtomicLong lastFreshTime;


    public RateLimiter(int maxPermit, int rate, int initPermits) {
        if (rate < 1) throw new IllegalArgumentException("the rate must be greater than 1");
        if (initPermits < 0) throw new IllegalArgumentException("the initPermits must be greater than 0");
        if (maxPermit < 1) throw new IllegalArgumentException("the maxPermit must be greater than 1");
        if (maxPermit < rate) throw new IllegalArgumentException("the maxPermit must be greater than rate");
        this.maxPermit = maxPermit;
        this.rate = rate;
        this.holdPermits = new AtomicInteger(initPermits);
        this.lastFreshTime = new AtomicLong(System.nanoTime());
    }

    /**
     * 嘗試獲取 permit 數量的令牌
     *
     * @param permit 令牌數量
     * @return 獲取到 permit 數量的令牌則返回 true, 否則返回 false
     */
    public boolean tryAcquire(int permit) {
        if (permit > maxPermit)
            throw new IllegalArgumentException("the permit must be smaller than maxPermit:" + maxPermit);
        if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
        freshPermit();

        int currentPermits;
        do {
            currentPermits = holdPermits.get();
            if (currentPermits < permit) {
                return false;
            }
        } while (!holdPermits.compareAndSet(currentPermits, currentPermits - permit));

        log.debug("原令牌數:" + currentPermits + ", 減少:" + permit + ", 當前總數:" + (currentPermits - permit));
        return true;
    }

    /**
     * 重新整理令牌數量
     */
    private void freshPermit() {
        long now = System.nanoTime();
        long lastTime = lastFreshTime.get();
        if (now <= lastTime) return;
        int increment = (int) ((now - lastTime) * rate / 1_000_000_000);
        long thisTime = lastTime + increment * 1_000_000_000L / rate;
        if (increment > 0 && lastFreshTime.compareAndSet(lastTime, thisTime)) {
            int current;
            int next;
            do {
                current = holdPermits.get();
                next = Math.min(maxPermit, current + increment);
            } while (!holdPermits.compareAndSet(current, next));
            log.debug("原令牌數:" + current + ", 增加:" + increment + ", 當前總數:" + holdPermits.get());
        }
    }

    /**
     * 在 timeout 時間內嘗試獲取 permit 數量的令牌
     *
     * @param permit  令牌數量
     * @param timeOut 超時時間 單位 秒
     * @return 如果在 timout 時間內獲取到 permit 數量的令牌則返回 true, 否則返回 false
     */
    public boolean tryAcquire(int permit, long timeOut) {
        if (permit < 1) throw new IllegalArgumentException("the permit must be greater than 1");
        if (timeOut < 0) throw new IllegalArgumentException("the timeOut must be greater than 0");
        timeOut = timeOut * 1_000_000_000 + System.nanoTime();
        long sleepDuration = (long) (1.0 * permit / rate);
        sleepDuration = Math.min(sleepDuration, 100);
        sleepDuration = Math.max(sleepDuration, 10);
        while (System.nanoTime() < timeOut) {
            if (tryAcquire(permit)) return true;
            else {
                try {
                    Thread.sleep(sleepDuration);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        RateLimiter rateLimiter = new RateLimiter(100, 50, 0);
        log.info("開始");
        for (int i = 0; i < 3; i++) {
            int j = i;
            new Thread(() -> {
                int k = 0;
                while (true) {
                    if (rateLimiter.tryAcquire(20, 1)) {
                        log.info("第" + k++ + "輪, 執行緒 " + j + " 獲取令牌成功");
                    } else log.error("第" + k++ + "輪, 執行緒 " + j + " 獲取令牌失敗");
                    try {
                        Thread.sleep(888);
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                }
            }).start();
        }
    }
}

總結

這個案例中其實還是可以學習到不少關於如何執行緒安全的實現功能的問題。

相關文章