阿里 TransmittableThreadLocal 程式碼簡讀

三流 發表於 2022-06-08
MIT

零 準備

0 FBI WARNING

文章異常囉嗦且繞彎。

1 TransmittableThreadLocal 是什麼

當開發人員需要線上程池的執行緒中傳遞某些引數的時候,jdk 的 ThreadLocal 很難實現,靜態變數則會面臨不夠靈活和出現執行緒安全等問題。
TransmittableThreadLocal 是阿里開源工具包,用於解決這一問題。

2 版本

  • jdk 版本

Azul JDK 17.0.2

  • transmittable-thread-local

2.13.0-Beta1

  • junit-jupiter

5.8.2

一 Demo

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import org.junit.jupiter.api.Test;

import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TreadLocalTest {

    @Test
    public void transmittableThreadLocal() {

        TransmittableThreadLocal<Integer> tl = new TransmittableThreadLocal<>();

        tl.set(6);
        System.out.println("父執行緒獲取資料:" + tl.get()); // 第一次輸出:6

        // 使用 jdk 的 Executors 工具建立一個執行緒池
        // 注意,這個執行緒池裡只有一個執行緒
        Executor realPool = Executors.newFixedThreadPool(1);
        
        // 使用 TtlExecutors 建立一個 Ttl 框架封裝的執行緒池
        Executor pool = TtlExecutors.getTtlExecutor(realPool);

        // 使用執行緒池跑一個任務
        pool.execute(() -> {
            Integer i = tl.get();
            System.out.println("第一次獲取資料:" + i); // 第二次輸出:6
        });

        // 修改一下 tl 裡的值,並再跑一次任務
        tl.set(7);
        pool.execute(() -> {
            Integer i = tl.get();
            System.out.println("第二次獲取資料:" + i); // 第三次輸出:7
        });
    }
}

二 先從 InheritableThreadLocal 說起

1 Thread

InheritableThreadLocal 是 jdk 中自帶的 ThreadLocal 的子類,在 jdk 的 Thread 物件中,會對它有單獨的支援。
首先來看 Thread 的構造方法:

// java.lang.Thread 的核心構造方法
private Thread(ThreadGroup g, Runnable target, String name,
                   long stackSize, AccessControlContext acc,
                   boolean inheritThreadLocals) {
    
    // 此處省略一大段無關程式碼...
    
    // inheritThreadLocals 是一個 boolean 型別的值,是一個 “是否啟用 inheritableThreadLocals” 的開關
    // parent 是創造此執行緒的父執行緒
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        // 如果父執行緒的 inheritableThreadLocals 存在,則此處會將它挪到當前執行緒裡
        // ThreadLocal.createInheritedMap 是一個深拷貝,會建立新的 Entry
        this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    
    // 此處省略一大段無關程式碼...
}

2 ThreadLocalMap

來看一下 ThreadLocal.createInheritedMap:

// java.lang.ThreadLocal
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

這個方法會建立一個 ThreadLocalMap,再來追蹤一下 ThreadLocalMap 的構造器:
(值得注意的是,ThreadLocalMap 是 ThreadLocal 的內部類,所以其實程式碼邏輯還是在 ThreadLocal.java 中)

// java.lang.ThreadLocal
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    // 此處把 ThreadLocalMap 裡的元素都遍厲一遍
    // 然後都建立成新的 Entry 並塞到新的 ThreadLocalMap 裡
    for (Entry e : parentTable) {
        if (e != null) {
            // 此處獲取了 Entry 的 key,本質上就是 ThreadLocal 本身
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                // 用 key 取獲取 value,這行程式碼重點關注,下文會提到
                Object value = key.childValue(e.value);
                // 此處建立新的 Entry
                Entry c = new Entry(key, value);
                // 處理 hash 碰撞問題並存入
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

3 childValue

這裡需要重點關注一行程式碼:

Object value = key.childValue(e.value);

這個方法是 ThreadLocal 中的:

// java.lang.ThreadLocal
T childValue(T parentValue) {
    throw new UnsupportedOperationException();
}

由上文可見,這是一個沒有被實現的預留模板方法。在 InheritableThreadLocal 中對其進行了實現:

// java.lang.InheritableThreadLocal
protected T childValue(T parentValue) {
    return parentValue;
}

5 initialValue

initialValue 同樣是 ThreadLocal 提供的一個空方法:

// java.lang.ThreadLocal
protected T initialValue() {
    return null;
}

這個方法會作用在 ThreadLocal 的 get() 方法裡:

// step 1
// java.lang.ThreadLocal
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            // 如果 Entry 存在,則此處會返回 Entry 的 value
            T result = (T)e.value;
            return result;
        }
    }
    // 如果 Entry 不存在,或者 ThreadLocalMap 不存在,會在這裡初始化一個 value
    // 這個方法見 step 2
    return setInitialValue();
}

// step 2
// java.lang.ThreadLocal
private T setInitialValue() {
    // 這裡初始化一個值
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 將初始化出來的值存進去
        map.set(this, value);
    } else {
        // 初始化 ThreadLocalMap
        createMap(t, value);
    }
    
    // 此處忽略這段程式碼
    if (this instanceof TerminatingThreadLocal) {
        TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
    }
    
    // 返回
    return value;
}

5 InheritableThreadLocal 的作用和問題

假設 Thread A 是 Thread B 的父執行緒,由上述程式碼可知:

  • A 的 InheritableThreadLocal 內的資料可以被 B 繼承
  • 繼承方式是在建立 B 的時候,在構造方法裡直接 copy 一份 InheritableThreadLocal 內的元素
  • copy 是一個快照機制,一旦結束,再去修改 A 中的 InheritableThreadLocal 中的元素,就不會同步給 B 了

那麼問題來了:
如果系統中需要做到 A 和 B 的 InheritableThreadLocal 實時同步,應該如何解決?

三 TransmittableThreadLocal

先來看下列三行程式碼:

// 建立一個 TransmittableThreadLocal
ThreadLocal<Integer> tl = new TransmittableThreadLocal<>();

tl.set(6);

Integer i = tl.get();

1 構造器

TransmittableThreadLocal 的構造器非常簡單。

// 是否要忽略 null value,如果這個引數為 false,則哪怕 value 是 null,也會儲存下來
private final boolean disableIgnoreNullValueSemantics;

// 這個引數預設為 false
public TransmittableThreadLocal() {
    this(false);
}


public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {
    this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;
}

2 holder

holder 是 TransmittableThreadLocal 的靜態成員變數,是一個 InheritableThreadLocal。

// com.alibaba.ttl.TransmittableThreadLocal
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
    new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
    
    // 複寫這個方法應該沒有別的深意,只是為了防止在呼叫 holder.get().xxx() 的時候報空指標
    // 應該是開發人員覺得這樣比較優雅
    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
        return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    }

    // 這個方法實現了子執行緒和父執行緒之間的資訊傳遞
    @Override
    protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
        return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
    }
};

由上述可知:

  • holder 是一個記錄的 value 是 WeakHashMap<TransmittableThreadLocal> 的 InheritableThreadLocal
  • WeakHashMap 的 value 並沒有被使用到,可以將其視為一個 WeakHashSet
  • holder 複寫了 initialValue 和 childValue 兩個方法

holder 最重要的方法是 addThisToHolder:

// com.alibaba.ttl.TransmittableThreadLocal
// 如果當前 TransmittableThreadLocal 沒有被記錄在 holder 中,則會在此處 put 進去
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}

同樣還有移除方法:

// com.alibaba.ttl.TransmittableThreadLocal
private void removeThisFromHolder() {
    holder.get().remove(this);
}

3 set

存入 value 的方法。

// com.alibaba.ttl.TransmittableThreadLocal
@Override
public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && null == value) {
        // 如果 value 是 null,且不忽略 null value,則此處進入刪除邏輯
        remove();
    } else {
        // 儲存邏輯
        super.set(value);
        // 將當前的 TransmittableThreadLocal 註冊到 holder 裡
        addThisToHolder();
    }
}

4 get

獲取 value 的方法。

// com.alibaba.ttl.TransmittableThreadLocal
@Override
public final T get() {
    T value = super.get();
    // 嘗試註冊到 holder
    if (disableIgnoreNullValueSemantics || null != value) 
        addThisToHolder();
    return value;
}

5 Snapshot

Snapshot 是 TransmittableThreadLocal 的內部類,用來存放當前執行緒內的 ThreadLocal 和 TransmittableThreadLocal 資料。

// com.alibaba.ttl.TransmittableThreadLocal
private static class Snapshot {
    final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;

    private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
        this.ttl2Value = ttl2Value;
        this.threadLocal2Value = threadLocal2Value;
    }
}

6 Transmitter

Transmitter 是 TransmittableThreadLocal 的內部類,本質上是一組靜態工具。

6.1 獲取一個快照

// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
public static Object capture() {
    // captureTtlValues()  會將當前執行緒的 TransmittableThreadLocal 資料做成一個 HashMap
    // captureThreadLocalValues() 會將當前執行緒的 ThreadLocal 資料做成一個 HashMap
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
6.1.1 獲取 holder 中所有的 TransmittableThreadLocal 資料
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    
    HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
    
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    
    return ttl2Value;
}
6.1.2 獲取 threadLocalHolder 中所有 ThreadLocal 資料
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
    
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
    
    for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        final TtlCopier<Object> copier = entry.getValue();

        threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
    }
    
    return threadLocal2Value;
}

6.2 重放

6.2.1 replay
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本質上是對一個 snapshot 進行拷貝
public static Object replay(Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
6.2.2 replayTtlValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本質上是對一個 map 進行深拷貝
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    
    // 建立一個新的 map
    HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // 將原來的 map 複製到新的 map 中
        backup.put(threadLocal, threadLocal.get());

        // 此處比較 holder 和 captured 的 key
        // 如果對應不一致,則將 holder 裡的資料清空
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // 將 value 和 key 對應起來
    // 這是一個保底糾錯邏輯
    setTtlValuesTo(captured);

    // 這是一個暫時沒有用的擴充套件方法
    doExecuteCallback(true);

    return backup;
}
6.2.3 replayThreadLocalValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 本質上是對一個 map 進行深拷貝
private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(HashMap<ThreadLocal<Object>, Object> captured) {
    final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<ThreadLocal<Object>, Object>();

    for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        backup.put(threadLocal, threadLocal.get());

        // threadLocalClearMark 是一個空物件,用於佔位
        // 如果此處的 value 就是這個空物件,則此處代表這個 ttl 裡的 value 已經被 clear 了
        final Object value = entry.getValue();
        if (value == threadLocalClearMark) 
            threadLocal.remove();
        else 
            threadLocal.set(value);
    }

    return backup;
}

6.3 恢復

6.3.1 restore
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
// 用快照來恢復當前執行緒的 ttl 資料
public static void restore(Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}
6.3.2 restoreTtlValues

這個方法與 replayTtlValues(...) 方法比較像

// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static void restoreTtlValues(HashMap<TransmittableThreadLocal<Object>, Object> backup) {
    doExecuteCallback(false);

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        if (!backup.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    setTtlValuesTo(backup);
}
6.3.3 restoreThreadLocalValues
// com.alibaba.ttl.TransmittableThreadLocal.Transmitter
private static void restoreThreadLocalValues(HashMap<ThreadLocal<Object>, Object> backup) {
    for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

四 ExecutorTtlWrapper

1 ExecutorTtlWrapper

ExecutorTtlWrapper 的程式碼非常少:

// com.alibaba.ttl.threadpool.ExecutorTtlWrapper
class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    
    // 這個變數代表了一個執行緒池
    private final Executor executor;
    // 這個變數是一個冪等識別符號
    protected final boolean idempotent;

    ExecutorTtlWrapper(Executor executor, boolean idempotent) {
        this.executor = executor;
        this.idempotent = idempotent;
    }

    @Override
    public void execute(Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }

    @Overrid
    public Executor unwrap() {
        return executor;
    }

    // 其它方法不重要,這裡省略...
}

ExecutorTtlWrapper 本質上是一個執行緒池的代理,在執行 execute(...) 方法的時候,會將 Runnable 任務包裝成 TtlRunnable。

2 TtlEnhanced

// 這是一個單純的空介面,用來標識一個類
public interface TtlEnhanced {
    
}

3 TtlWrapper

// TtlWrapper 用來標識一個包裝類
// 需要實現獲取被包裝物件的 unwrap 方法
public interface TtlWrapper<T> extends TtlEnhanced {
    T unwrap();
}

4 TtlExecutors

TtlExecutors 是一個靜態工具類,用來生成 ExecutorTtlWrapper。

// com.alibaba.ttl.threadpool.TtlExecutors
public static Executor getTtlExecutor(Executor executor) {
    // 如果已經包裝過了,那麼此處直接返回
    if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
        return executor;
    }
    
    // 如果沒有包裝過,那麼此處包裝一下
    // 冪等識別符號,此處預設為 true
    return new ExecutorTtlWrapper(executor, true);
}

TtlAgent 是對探針技術的應用,暫時不展開講解。

五 TtlRunnable

1 TtlRunnable

首先來看一下 class:

// com.alibaba.ttl.TtlRunnable
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseTtlValueReferenceAfterRun;
    
    private TtlRunnable(Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // capture() 方法見上面 第三 part 的 Transmitter 部分
        // 本質上這是當前執行緒所儲存的 TransmittableThreadLocal 和 ThreadLocal 的快照
        this.capturedRef = new AtomicReference<Object>(capture());
        // 真實的業務邏輯
        this.runnable = runnable;
        // 當前 TtlRunnable 是否可以重複執行
        // true 的情況下,只要執行完,就不能重複執行了
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }
    
    // 其它方法先省略...   
}

2 get

TtlRunnable.get(...) 是一個靜態方法,用於建立一個 TtlRunnable 物件。

// com.alibaba.ttl.TtlRunnable
public static TtlRunnable get(Runnable runnable) {
    return get(runnable, false, false);
}

public static TtlRunnable get(Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    return get(runnable, releaseTtlValueReferenceAfterRun, false);
}

public static TtlRunnable get(Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    // 空判斷
    if (null == runnable) 
        return null;

    // 如果當前為冪等,則此處複用
    if (runnable instanceof TtlEnhanced) {
        if (idempotent) 
            return (TtlRunnable) runnable;
        else 
            throw new IllegalStateException("Already TtlRunnable!");
    }
    
    // 建立物件
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

3 run

TtlRunnable.run() 是核心方法,是對業務邏輯的封裝。

// com.alibaba.ttl.TtlRunnable
public void run() {
    
    // 獲取當前快照
    final Object captured = capturedRef.get();
    
    // 有效性判斷
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    // replay 方法來自 Transmitter
    // 用於建立一個當前執行緒的 ThreadLocal 的備份
    final Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        // restore 方法來自 Transmitter
        // 使用備份來恢復當前執行緒的 ThreadLocal 資料
        restore(backup);
    }
}

captured 實際上是一個備忘錄模式,用於確保子執行緒內的資料修改不影響到父執行緒。

六 一點嘮叨

  • 封裝的很有意思,但是很多細節還是沒太看懂
  • 僅為個人的學習筆記,可能存在錯誤或者表述不清的地方,有緣補充