BAT 經典演算法筆試題 —— 磁碟多路歸併排序

碼洞發表於2019-01-18

在 LevelDB 資料庫中高層資料下沉到低層時需要經歷一次 Major Compaction,將高層檔案的有序鍵值對和低層檔案的多個有序鍵值對進行歸併排序。磁碟多路歸併排序演算法的輸入是來自多個磁碟檔案的有序鍵值對,在記憶體中將這些檔案的鍵值對進行排序,然後輸出到一到多個新的磁碟檔案中。

圖片

多路歸併排序在大資料領域也是常用的演算法,常用於海量資料排序。當資料量特別大時,這些資料無法被單個機器記憶體容納,它需要被切分位多個集合分別由不同的機器進行記憶體排序(map 過程),然後再進行多路歸併演算法將來自多個不同機器的資料進行排序(reduce 過程),這是流式多路歸併排序,為什麼說是流式排序呢,因為資料來源來源於網路套接字。

圖片

多路歸併排序的優勢在於記憶體消耗極低,它的記憶體佔用和輸入檔案的數量成正比,和資料總量無關,資料總量只會線性正比影響排序的時間。

下面我們來親自實現一下磁碟多路歸併演算法,為什麼是磁碟,因為它的輸入來自磁碟檔案。

演算法思路

我們需要在記憶體裡維護一個有序陣列。每個輸入檔案當前最小的元素作為一個元素放在陣列裡。陣列按照元素的大小保持排序狀態。

圖片

接下來我們開始進入迴圈,迴圈的邏輯總是從最小的元素下手,在其所在的檔案取出下一個元素,和當前陣列中的元素進行比較。根據比較結果進行不同的處理,這裡我們使用二分查詢演算法進行快速比較。注意每個輸入檔案裡面的元素都是有序的。

  1. 如果取出來的元素和當前陣列中的最小元素相等,那麼就可以直接將這個元素輸出。再繼續下一輪迴圈。不可能取出比當前陣列最小元素還要小的元素,因為輸入檔案本身也是有序的。

圖片

  1. 否則就需要將元素插入到當前的陣列中的指定位置,繼續保持陣列有序。然後將陣列中當前最小的元素輸出並移除。再進行下一輪迴圈。

圖片
3. 如果遇到檔案結尾,那就無法繼續呼叫 next() 方法了,這時可以直接將陣列中的最小元素輸出並移除,陣列也跟著變小了。再進行下一輪迴圈。當陣列空了,說明所有的檔案都處理完了,演算法就可以結束了。

圖片
值得注意的是,陣列中永遠不會存在同一個檔案的兩個元素,如此才保證了陣列的長度不會超過輸入檔案的數量,同時它也不會把沒有結尾的檔案擠出陣列導致漏排序的問題。

二分查詢

需要特別注意的是Java 內建了二分查詢演算法在使用上比較精巧。

public class Collections {
  ...
  public static <T> int binarySearch(List<T> list, T key) {
    ...
    if (found) {
      return index;
    } else {
      return -(insertIndex+1);
    }
  }
  ...
}
複製程式碼

如果 key 可以在 list 中找到,那就直接返回相應的位置。如果找不到,它會返回負數,還不是簡單的 -1,這個負數指明瞭插入的位置,也就是說在這個位置插入 key,陣列將可以繼續保持有序。

比如 binarySearch 返回了 index=-1,那麼 insertIndex 就是 -(index+1),也就是 0,插入點在陣列開頭。如果返回了 index=-size-1,那麼 insertIndex 就是 size,是陣列末尾。其它負數會插入陣列中間。

圖片

輸入檔案類

對於每一個輸入檔案都會建立一個 MergeSource 物件,它提供了 hasNext() 和 next() 方法用於判斷和獲取下一個元素。注意輸入檔案是有序的,下一個元素就是當前輸入檔案最小的元素。 hasNext() 方法負責讀取下一行並快取在 cachedLine 變數中,呼叫 next() 方法將 cachedLine 變數轉換成整數並返回。

class MergeSource implements Closeable {
	private BufferedReader reader;
	private String cachedLine;
	private String filename;

	public MergeSource(String filename) {
		this.filename = filename;
		try {
            FileReader fr = new FileReader(filename);
			this.reader = new BufferedReader(fr);
		} catch (FileNotFoundException e) {
		}
	}

	public boolean hasNext() {
		String line;
		try {
			line = this.reader.readLine();
			if (line == null || line.isEmpty()) {
				return false;
			}
			this.cachedLine = line.trim();
			return true;
		} catch (IOException e) {
		}
		return false;
	}

	public int next() {
		if (this.cachedLine == null) {
			if (!hasNext()) {
				throw new IllegalStateException("no content");
			}
		}
		int num = Integer.parseInt(this.cachedLine);
		this.cachedLine = null;
		return num;
	}

	@Override
	public void close() throws IOException {
		this.reader.close();
	}
}
複製程式碼

記憶體有序陣列元素類

在排序前先把這個陣列準備好,將每個輸入檔案的最小元素放入陣列,並排序。

class Bin implements Comparable<Bin> {
	int num;
	MergeSource source;

	Bin(MergeSource source, int num) {
		this.source = source;
		this.num = num;
	}

	@Override
	public int compareTo(Bin o) {
		return this.num - o.num;
	}

}

List<Bin> prepare() {
  	List<Bin> bins = new ArrayList<>();
	for (MergeSource source : sources) {
		Bin newBin = new Bin(source, source.next());
		bins.add(newBin);
	}
    Collections.sort(bins);
    return bins;
}
複製程式碼

輸出檔案類

關閉輸出檔案時注意要先 flush(),避免丟失 PrintWriter 中緩衝的內容。

class MergeOut implements Closeable {
	private PrintWriter writer;

	public MergeOut(String filename) {
		try {
            FileOutputStream out = new FileOutputStream(filename);
			this.writer = new PrintWriter(out);
		} catch (FileNotFoundException e) {
		}
	}

	public void write(Bin bin) {
		writer.println(bin.num);
	}

	@Override
	public void close() throws IOException {
		writer.flush();
		writer.close();
	}
}
複製程式碼

準備輸入檔案的內容

下面我們來生成一系列輸入檔案,每個輸入檔案中包含一堆隨機整數。一共生成 n 個檔案,每個檔案的整數數量在 minEntries 到 minEntries 之間。返回所有輸入檔案的檔名列表。

List<String> generateFiles(int n, int minEntries, int maxEntries) {
	List<String> files = new ArrayList<>();
	for (int i = 0; i < n; i++) {
		String filename = "input-" + i + ".txt";
		PrintWriter writer;
		try {
			writer = new PrintWriter(new FileOutputStream(filename));
            ThreadLocalRandom rand = ThreadLocalRandom.current();
			int entries = rand.nextInt(minEntries, maxEntries);
			List<Integer> nums = new ArrayList<>();
			for (int k = 0; k < entries; k++) {
				int num = rand.nextInt(10000000);
				nums.add(num);
			}
			Collections.sort(nums);
			for (int num : nums) {
				writer.println(num);
			}
			writer.flush();
			writer.close();
		} catch (FileNotFoundException e) {
		}
		files.add(filename);
	}
	return files;
}
複製程式碼

排序演算法

萬事俱備,只欠東風。將上面的類都準備好之後,排序演算法很簡單,程式碼量非常少。對照上面演算法思路來理解下面的演算法就很容易了。

public void sort() {
	List<Bin> bins = prepare();
	while (true) {
        // 取陣列中最小的元素
		MergeSource current = bins.get(0).source;
		if (current.hasNext()) {
            // 從輸入檔案中取出下一個元素
			Bin newBin = new Bin(current, current.next());
            // 二分查詢,也就是和陣列中已有元素進行比較
			int index = Collections.binarySearch(bins, newBin);
			if (index == 0) {
                // 演算法思路情況1
				this.out.write(newBin);
			} else {
                // 演算法思路情況2
				if (index < 0) {
					index = -(index+1);
				}
				bins.add(index, newBin);
				Bin minBin = bins.remove(0);
				this.out.write(minBin);
			}
		} else {
            // 演算法思路情況3:遇到檔案尾
			Bin minBin = bins.remove(0);
			this.out.write(minBin);
			if (bins.isEmpty()) {
				break;
			}
		}
	}
}
複製程式碼

全部程式碼

讀者可以直接將下面的程式碼拷貝貼上到 IDE 中執行。

package leetcode;

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

public class DiskMergeSort implements Closeable {

	public static List<String> generateFiles(int n, int minEntries, int maxEntries) {
		List<String> files = new ArrayList<>();
		for (int i = 0; i < n; i++) {
			String filename = "input-" + i + ".txt";
			PrintWriter writer;
			try {
				writer = new PrintWriter(new FileOutputStream(filename));
				int entries = ThreadLocalRandom.current().nextInt(minEntries, maxEntries);
				List<Integer> nums = new ArrayList<>();
				for (int k = 0; k < entries; k++) {
					int num = ThreadLocalRandom.current().nextInt(10000000);
					nums.add(num);
				}
				Collections.sort(nums);
				for (int num : nums) {
					writer.println(num);
				}
				writer.close();
			} catch (FileNotFoundException e) {
			}
			files.add(filename);
		}
		return files;
	}

	private List<MergeSource> sources;
	private MergeOut out;

	public DiskMergeSort(List<String> files, String outFilename) {
		this.sources = new ArrayList<>();
		for (String filename : files) {
			this.sources.add(new MergeSource(filename));
		}
		this.out = new MergeOut(outFilename);
	}

	static class MergeOut implements Closeable {
		private PrintWriter writer;

		public MergeOut(String filename) {
			try {
				this.writer = new PrintWriter(new FileOutputStream(filename));
			} catch (FileNotFoundException e) {
			}
		}

		public void write(Bin bin) {
			writer.println(bin.num);
		}

		@Override
		public void close() throws IOException {
			writer.flush();
			writer.close();
		}
	}

	static class MergeSource implements Closeable {
		private BufferedReader reader;
		private String cachedLine;

		public MergeSource(String filename) {
			try {
				FileReader fr = new FileReader(filename);
				this.reader = new BufferedReader(fr);
			} catch (FileNotFoundException e) {
			}
		}

		public boolean hasNext() {
			String line;
			try {
				line = this.reader.readLine();
				if (line == null || line.isEmpty()) {
					return false;
				}
				this.cachedLine = line.trim();
				return true;
			} catch (IOException e) {
			}
			return false;
		}

		public int next() {
			if (this.cachedLine == null) {
				if (!hasNext()) {
					throw new IllegalStateException("no content");
				}
			}
			int num = Integer.parseInt(this.cachedLine);
			this.cachedLine = null;
			return num;
		}

		@Override
		public void close() throws IOException {
			this.reader.close();
		}
	}

	static class Bin implements Comparable<Bin> {
		int num;
		MergeSource source;

		Bin(MergeSource source, int num) {
			this.source = source;
			this.num = num;
		}

		@Override
		public int compareTo(Bin o) {
			return this.num - o.num;
		}
	}

	public List<Bin> prepare() {
		List<Bin> bins = new ArrayList<>();
		for (MergeSource source : sources) {
			Bin newBin = new Bin(source, source.next());
			bins.add(newBin);
		}
		Collections.sort(bins);
		return bins;
	}

	public void sort() {
		List<Bin> bins = prepare();
		while (true) {
			MergeSource current = bins.get(0).source;
			if (current.hasNext()) {
				Bin newBin = new Bin(current, current.next());
				int index = Collections.binarySearch(bins, newBin);
				if (index == 0 || index == -1) {
					this.out.write(newBin);
					if (index == -1) {
						throw new IllegalStateException("impossible");
					}
				} else {
					if (index < 0) {
						index = -index - 1;
					}
					bins.add(index, newBin);
					Bin minBin = bins.remove(0);
					this.out.write(minBin);
				}
			} else {
				Bin minBin = bins.remove(0);
				this.out.write(minBin);
				if (bins.isEmpty()) {
					break;
				}
			}
		}
	}

	@Override
	public void close() throws IOException {
		for (MergeSource source : sources) {
			source.close();
		}
		this.out.close();
	}

	public static void main(String[] args) throws IOException {
		List<String> inputs = DiskMergeSort.generateFiles(100, 10000, 20000);
		// 執行多次看演算法耗時
		for (int i = 0; i < 20; i++) {
			DiskMergeSort sorter = new DiskMergeSort(inputs, "output.txt");
			long start = System.currentTimeMillis();
			sorter.sort();
			long duration = System.currentTimeMillis() - start;
			System.out.printf("%dms\n", duration);
			sorter.close();
		}
	}
}
複製程式碼

本演算法還有一個小缺陷,那就是如果輸入檔案數量非常多,那麼記憶體中的陣列就會特別大,對陣列的插入刪除操作肯定會很耗時,這時可以考慮使用 TreeSet 來代替陣列,讀者們可以自行嘗試一下。

BAT 經典演算法筆試題 —— 磁碟多路歸併排序

相關文章