併發程式設計之 ThreadLocal 原始碼剖析

莫那·魯道發表於2018-04-30

併發程式設計之 ThreadLocal 原始碼剖析

前言

首先看看 JDK 文件的描述:

該類提供了執行緒區域性 (thread-local) 變數。這些變數不同於它們的普通對應物,因為訪問某個變數(通過其 get 或 set 方法)的每個執行緒都有自己的區域性變數,它獨立於變數的初始化副本。ThreadLocal 例項通常是類中的 private static 欄位,它們希望將狀態與某一個執行緒(例如,使用者 ID 或事務 ID)相關聯。

每個執行緒都保持對其執行緒區域性變數副本的隱式引用,只要執行緒是活動的並且 ThreadLocal 例項是可訪問的;線上程消失之後,其執行緒區域性例項的所有副本都會被垃圾回收(除非存在對這些副本的其他引用)

例如,以下類生成對每個執行緒唯一的區域性識別符號。 執行緒 ID 是在第一次呼叫 UniqueThreadIdGenerator.getCur

java.lang.ThreadLocal 不是 1.5 新加入的類,在 1.2 的時候就已經存在 java 類庫的類,但該類的作用非常的大,所以我們也要剖析一下他的原始碼,也要驗證關於該類的一些爭論,比如記憶體洩漏。

1. 如何使用?

該類有4個方法方法需要關注:

public class ThreadLocal<T> {

    public T get();

    public void set(T value);

    public void remove();

    protected T initialValue();
    
複製程式碼

get() : 返回此執行緒區域性變數的當前執行緒副本中的值。如果變數沒有用於當前執行緒的值,則先將其初始化為呼叫  initialValue() 方法返回的值。

set():將此執行緒區域性變數的當前執行緒副本中的值設定為指定值。大部分子類不需要重寫此方法,它們只依靠 initialValue() 方法來設定執行緒區域性變數的值。

remove() :移除此執行緒區域性變數當前執行緒的值。如果此執行緒區域性變數隨後被當前執行緒讀取, 且這期間當前執行緒沒有設定其值,則將呼叫其 initialValue() 方法重新初始化其值。這將導致在當前執行緒多次呼叫 initialValue 方法。則不會對該執行緒再呼叫 initialValue 方法。通常,此方法對每個執行緒最多呼叫一次,但如果在呼叫 get() 後又呼叫了 remove() ,則可能再次呼叫此方法。

initialValue():返回此執行緒區域性變數的當前執行緒的“初始值”。執行緒第一次使用 get() 方法變數時將呼叫此方法,但如果執行緒之前呼叫了 set(T) 方法,

我們還是來個例子:

package cn.think.in.java.lock.tools;

public class ThreadLocalDemo {

  public static void main(String[] args) {
    ThreadLocal<String> local = new MyThreadLocal<>();
    System.out.println(local.get());
    local.set("hello");
    System.out.println(local.get());
    local.remove();
    System.out.println(local.get());
  }

}

class MyThreadLocal<T> extends ThreadLocal<T> {

  @Override
  protected T initialValue() {
    return (T) "world";
  }
}


複製程式碼

執行結果

world
hello
world
複製程式碼

上面的程式碼中,我們重寫了 ThreadLocal 的 initialValue 方法,返回了一個字串 “world”,第一次呼叫 get 方法返回了該值,而我們然後又呼叫 set 方法設定了 hello 字串,再次呼叫 get 方法,此時返回的就是剛剛set 的值 ---- hello,然後我們呼叫remove 方法,刪除 hello,再次呼叫 get 方法,返回了 initialValue 方法中的 world。

從這個流程中,我們已經知道了該類的用法,那麼我們就看看原始碼是如何實現的。

get() 原始碼剖析

    public T get() {
        // 獲取當前執行緒
        Thread t = Thread.currentThread();
        // 獲取當前執行緒的 ThreadLocalMap  物件
        ThreadLocalMap map = getMap(t);
        // 如果map不是null,將 ThreadlLocal 物件作為 key 獲取對應的值
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            // 如果該值存在,則返回該值
            if (e != null) {
                T result = (T)e.value;
                return result;
            }
        }
        // 如果上面的邏輯沒有取到值,則從 initialValue  方法中取值
        return setInitialValue();
    }
複製程式碼

樓主在該方法中寫了註釋,主要邏輯是從 當前執行緒中取出 一個類似 Map 的物件, map 中 key是 ThreadLocal 物件,value 則是我們設定的值。如果 該 map中沒有,則從 initialValue 方法中取。

我們繼續看看,map 的真實面目:

併發程式設計之 ThreadLocal 原始碼剖析

併發程式設計之 ThreadLocal 原始碼剖析

併發程式設計之 ThreadLocal 原始碼剖析

就是這個map,這個map 持有一個 Entry 陣列,Entry 繼承了 WeakReference ,也就是弱引用,如果一個物件具有弱引用,在GC執行緒掃描記憶體區域的過程中,不管當前記憶體空間足夠與否,都會回收記憶體。這個特性我們之後再說。

總的來說,每個執行緒物件中都有一個 ThreadLocalMap 屬性,該屬性儲存 ThreadLocal 為 key ,值則是我們呼叫 ThreadLocal 的 set 方法設定的,也就是說,一個ThreakLocal 物件對應一個 value。

還沒完,我們看看 getEntry 方法:

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
複製程式碼

如果hash 沒有衝突,直接返回 對應的值,如果衝突了,呼叫 getEntryAfterMiss 方法。

getEntryAfterMiss 原始碼:

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

複製程式碼

該方法會迴圈所有的元素,直到找到 key 對應的 entry,如果發現了某個元素的 key 是 null,順手呼叫 expungeStaleEntry 方法清理 所有 key 為 null 的 entry。

那麼 set 方法是怎麼樣的呢?

3. set() 方法原始碼剖析

原始碼如下:

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
複製程式碼

該方法同樣先得到當前執行緒,然後根據當前執行緒得到執行緒的 ThreadLocalMap 屬性,如果 Map 為null, 則建立一個Map ,並將值放置到Map中,否則,直接將值放置到Map中。

先看看 createMap(Thread t, T firstValue) 方法:

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY]; // 預設長度16
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); // 得到下標
            table[i] = new Entry(firstKey, firstValue); // 建立一個entry物件並插入陣列
            size = 1; // 設定長度屬性為1
            setThreshold(INITIAL_CAPACITY); 設定閥值== 16 * 2 / 3 == 10
        }

複製程式碼

這個方法很簡單,樓主已經寫了詳細的註釋。就是建立一個16長度的entry 陣列,設定閥值為10,注意,再resize 的時候,並不是10,而是 10 - 10 / 4,也就是 8,負載因子為 0.5,和 HashMap 是不同的。

我們再看看 map.set(ThreadLocal<?> key, Object value) 方法如何實現的:

       private void set(ThreadLocal<?> key, Object value) {
            
            Entry[] tab = table;
            int len = tab.length;
            // 根據 ThreadLocal 的 HashCode 得到對應的下標
            int i = key.threadLocalHashCode & (len-1);
            // 首先通過下標找對應的entry物件,如果沒有,則建立一個新的 entry物件
            // 如果找到了,但key衝突了或者key是null,則將下標加一(加一後如果小於陣列長度則使用該值,否則使用0),
            // 再次嘗試獲取對應的 entry,如果不為null,則在迴圈中繼續判斷key 是否重複或者k是否是null
            for (Entry e = tab[i]; e != null;  e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                // key 相同,則覆蓋 value
                if (k == key) {
                    e.value = value;
                    return;
                }
                // 如果key被 GC 回收了(因為是軟引用),則建立一個新的 entry 物件填充該槽
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            
            // 建立一個新的 entry 物件
            tab[i] = new Entry(key, value);
            // 長度加一
            int sz = ++size;
            // 如果沒有清楚多餘的entry 並且陣列長度達到了閥值,則擴容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
複製程式碼

這裡樓主剛開始有一個奇怪的地方,為什麼這裡和 HashMap 處理 Hash 衝突的方式不一樣,樓主後來查詢資料,才明白,HashMap 的Hash衝突方法是拉鍊法,即用連結串列來處理,而 ThreadLocalMap 處理Hash衝突採用的是線性探測法,即這個槽不行,就換下一個槽,直到插入為止。但是該方法有一個問題,就是,如果整個陣列都衝突了,就會不停的迴圈,導致死迴圈,雖然這種機率很小。

我們繼續往下。

如果 k == null,表示 ThreadLocal 被GC回收了,那麼就呼叫 replaceStaleEntry 方法重新生成一個 entry,不過該方法沒有我說的那麼簡單,我們來看看:

        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
複製程式碼

該方法可以說有點複雜,樓主看了很久,真的沒想到 ThreadLocal 這麼複雜。。。。。如同該方法名稱,該方法會刪除陳舊的 entyr,什麼是陳舊的呢,就是 ThreadLocal 為 null 的 entry,會將 entry key 為 null 的物件設定為null。核心的方法就是 expungeStaleEntry(int);

整體邏輯就是,通過線性探測法,找到每個槽位,如果該槽位的key為相同,就替換這個value;如果這個key 是null,則將原來的entry 設定為null,並重新建立一個entry。

不論如何,只要走到了這裡,都會清除所有的 key 為null 的entry,也就是說,當hash 衝突的時候並且對應的槽位的key值是null,就會清除所有的key 為null 的entry。

我們回到 set 方法。如果 hash 沒有衝突,也會呼叫 cleanSomeSlots 方法,該方法同樣會清除無用的 entry,也就是 key 為null 的節點。我們看看程式碼:

      private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }
複製程式碼

該方法會遍歷所有的entry,並判斷他們的key,如果key是null,則呼叫 expungeStaleEntry 方法,也就是清除 entry。最後返回 true。

如果返回了 false ,說明沒有清除,並且 size 還 大於等於 10 ,就需要 rahash,該方法如下:

       private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }
複製程式碼

首先會呼叫 expungeStaleEntries 方法,該方法會清除無用的 entry,我們之前說過了,同時,也會對 size 變數做減法,如果減完之後,size 還大於 8,則呼叫 resize 方法做真正的擴容。

resize 方法如下:

        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }
複製程式碼

該方法會直接擴容為原來的2倍,並將老陣列的資料都移動到 新陣列,size 變數記錄了裡面有多少資料,最後設定擴容閥值為 2/3。

所以說,擴容分為2個步驟,當長度達到了容量的2/3,就會清理無用的資料,如果清理完之後,長度還大於等於閥值的3/4,那麼就做真正的擴容。而不是網上很多人說的達到了 2/3 就擴容。這裡的誤區就是擴容之前需要清理。清理完之後再做判斷。

可以看到,每次呼叫set 方法都會進行清理工作。實際上,如果使用 get 方法,當對應的 entry 的key為null 的時候,也會進行清理。

4. remove 方法原始碼剖析

        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
複製程式碼

通過線性探測法找到 key 對應的 entry,呼叫 clear 方法,將 ThreadLocal 設定為null,呼叫 expungeStaleEntry 方法,該方法順便會清理所有的 key 為 null 的 entry。

5. Thread 執行緒退出時清理 ThreadLocal

Thread 的exit 方法:

    private void exit() {
        if (group != null) {
            group.threadTerminated(this);
            group = null;
        }
        /* Aggressively null out all reference fields: see bug 4006245 */
        target = null;
        /* Speed the release of some of these resources */
        threadLocals = null;
        inheritableThreadLocals = null;
        inheritedAccessControlContext = null;
        blocker = null;
        uncaughtExceptionHandler = null;
    }
複製程式碼

可以看到,該方法會將執行緒相關的所有屬性變數全部清除。包括 threadLocals。

總結

樓主開始以為這個類的程式碼不會很難,想來樓主太天真了。從原始碼中我們可以看到,ThreadLocal 類的作者無時無刻都在想著如何去除那些 key 為 null 的 元素,為什麼?因為只要執行緒不退出,這些變數都會一直留線上程中。

但是,Java 中有執行緒池的技術,也就是說,執行緒基本不會退出,因此,就需要手動去刪除這些變數。如果你線上程中放置了一個大大的物件,使用完幾次後沒有清除(呼叫 remove 方法),該物件將會一直留線上程中。造成了記憶體洩漏。

為什麼要使用弱引用呢?我們假設一下,不用弱引用,如果我們使用的 ThreadLocal 的變數是個區域性變數,並設定到了執行緒中,當這個方法結束時,我們沒有呼叫 remove 方法,而 Map 中 key 不是弱引用,那麼該變數將會一直存在!!!

如果使用了弱引用,就算你沒有呼叫 remove 方法,GC 也會清除掉 Map 中的引用,同時,ThreadLocal 也會通過對 key 是否為 null 進行判斷,從而防止記憶體洩漏。

這裡我們重新總結一下:ThreadLocal 的作者之所以使用弱引用,是擔心程式設計師使用了區域性變數的ThreadLocal 並且沒有呼叫 remove 方法,這將導致沒有結束的執行緒發生記憶體洩漏。使用弱引用,即使程式設計師沒有刪除,GC 也會將該變數設定為null,ThrealLocal 通過判斷 key 是否為 null 來清除無用資料。防止記憶體洩漏。

當然,如果你使用的是靜態變數,並且使用結束後沒有設定為 null, ThrealLocal 是無法自動刪除的,因此需要呼叫 remove 方法。

那麼,ThrealLocal 什麼時候會自動回收呢?當呼叫 remove 方法的時候(廢話),當呼叫 get 方法並且 hash 衝突了的時候(情況很少),呼叫 set 方法時 hash 衝突了,呼叫 set 方法時正常插入。注意,呼叫 set 方法時,如果是覆蓋操作,則不會執行清理。

我們正常使用 ThreadLocal 都是靜態變數,也是 JDK 建議的例子,所以一定要手動呼叫 remove 方法,或者使用完畢後置為 null。反之,你可以碰運氣不好,JDK 可能會幫你刪,比如在你 set 的時候(也就是我們上面說的那幾種情況),如果運氣不好,就會永遠存線上程中,導致記憶體洩漏。

所以,強烈建議手動呼叫 remove 方法。

相關文章