遞迴、迭代和動態規劃:以九宮格鍵盤為例

sorra發表於2022-06-02

本文記述筆者的演算法學習心得,由於筆者不熟悉演算法,如有錯誤請不吝指正。

九宮格鍵盤問題

給定一個從數字到字母的對映表:1 => [a, b, c], 2 => [d, e, f], 3=> [g, h, i], ......,實現一個函式List<String> calculate(String input),若使用者輸入任意的數字串,要依據對映錶轉換每一個數字,返回所有可能的轉換結果。例如,若輸入“12”會得到ad, ae, af, bd, be, bf, cd, ce, cf這9個結果;若輸入“123”會得到27個結果。

以上是此題的基礎版,還有一個難度加強版:對映表的鍵還可以是數字組合,如12 => [x], 23 => [y]。若輸入“12”會有10個結果,還會有x這1個結果(總共10個結果);若輸入“123”,不但有之前的27個結果,還會有xg, xh, xi, ay, by, cy這6個結果(總共33個結果)。

基礎版的解法

可以用遞迴法也可以用迭代法,兩種方法是等價的。以下是演算法思想和實現程式碼。

遞迴法

遞迴法的思想是持續把問題降解為子問題,直到子問題足夠小。例如,calculate("123")等價於calculate("1") * calculate("23"),而calculate("23")等價於calculate("2") * calculate("3")。這裡引入了一種特殊的集合乘法,能把如[a, b, c]和[d, e, f]這樣的兩個集合相乘得到9個結果(即輸入“12”對應的結果),這種乘法可以用一個函式來實現。如此就清晰地表述了遞迴法的演算法框架。

實現程式碼為

static List<String> calculate(String input) {
  if (input.isEmpty()) {
    return Collections.emptyList();
  }

  String key = String.valueOf(input.charAt(0));
  List<String> values = mappings.get(key);
  String substring = input.substring(1);

  if (substring.isEmpty()) {
    return values;
  } else {
    return product(values, calculate(substring));
  }
}

// 乘法
static List<String> product(List<String> lefts, List<String> rights) {
  List<String> results = new ArrayList<>();
  for (String left : lefts) {
    for (String right : rights) {
      results.add(left + right);
    }
  }
  return results;
}

迭代法

迭代法的思想是依次處理輸入的每一位,用已得結果集表示當前狀態,記住和更新當前狀態。例如,calculate(“123”)的處理過程是,第一步處理”1”,得到初始狀態[a, b, c],第二步處理”2”,讓當前狀態[a, b, c]乘以“2”對應的[d, e, f],得到新狀態[ad, ae, af, bd, be, bf, cd, ce, cf],第三步處理”3”,讓當前狀態乘以”3”對應的[g, h, i],得到新狀態,此時輸入值已全部處理完成。

實現程式碼為

static List<String> calculate(String input) {
  List<String> results = Collections.emptyList();

  for (int i = 0; i < input.length(); i++) {
    String key = String.valueOf(input.charAt(i));
    List<String> values = mappings.get(key);

    if (results.isEmpty()) {
      results = values;
    } else {
      results = product(results, values);
    }
  }

  return results;
}

static List<String> product(List<String> lefts, List<String> rights) {
  List<String> results = new ArrayList<>();
  for (String left : lefts) {
    for (String right : rights) {
      results.add(left + right);
    }
  }
  return results;
}

難度加強版的解法

基礎版直接把每1位數字作為key,用mappings.get(key)獲取對映值。而難度加強版面對不等長的key,需要換一個方式,遍歷mappings中的每一個key,判斷當前輸入是否以這個key開頭,若是,就採用這個key的對映值,同時從當前輸入去掉這個key,然後繼續處理剩餘輸入。
基礎版的遞迴法略加修改就能得到難度加強版的解法。迭代法不能修改得到新解法,因為每一步可能產生多個步長不同的分支,這些分支所需處理的剩餘輸入是不同的,不能單調迭代,用演算法理論來說,由於這裡的遞迴法不是尾遞迴,因此沒有等價的迭代法版本。可以用棧或佇列將遞迴計算轉換為延續(continuation)來實現迭代計算,其實相當於模擬了遞迴,這個留給讀者自行完成。
其實迭代法可以更高效地實現,但是要使用動態規劃法。

遞迴法

實現程式碼為

static List<String> calculate(String input) {
  if (input.isEmpty()) {
    return Collections.emptyList();
  }

  List<String> results = new ArrayList<>();

  for (String key : mappings.keySet()) {
    if (input.startsWith(key)) {
      List<String> values = mappings.get(key);

      if (values == null) {
        throw new IllegalArgumentException("Unrecognized input");
      }

      String substring = input.substring(key.length());
      if (substring.isEmpty()) {
        results.addAll(values);
      } else {
        results.addAll(product(values, calculate(substring)));
      }
    }
  }

  return results;
}

// 乘法
static List<String> product(List<String> lefts, List<String> rights) {
  List<String> results = new ArrayList<>();
  for (String left : lefts) {
    for (String right : rights) {
      results.add(left + right);
    }
  }
  return results;
}

遞迴法的效能優化

來找找優化點。遞迴過程有很多重複的計算,可以用memoization技術搞一個快取來提速,據測試能提速一倍,快取的資料結構可以用雜湊表或陣列來實現。雜湊表以calculate()的input引數作為鍵,以calculate()的結果作為值;陣列以input.length()作為索引值,以calculate()的結果作為元素值。若輸入規模為n,快取的空間複雜度為O(n)。這個就交給讀者自行實現吧。

results.addAll(product(values, calculate(substring)));這行是可以優化的,product()返回一個新建的list,這個list又被複制到results裡,存在一些記憶體分配和複製的開銷。如果product()能直接輸出到results而不是返回一個list,就不用新建和複製一個list。據測試能提速10%左右。
程式碼修改為

// 呼叫處這麼改
product(values, calculate(substring), results);

// 定義處這麼改
static List<String> product(List<String> lefts, List<String> rights, List<String> results) {
  for (String left : lefts) {
    for (String right : rights) {
      results.add(left + right);
    }
  }
}

String substring = input.substring(key.length());這行好像也可以優化,substring()返回一個新建的子字串,存在一些記憶體分配和複製的開銷。其實,在舊版JVM上有一個優化,substring()是像subList()一樣的檢視物件,直接引用原始字串而不是建立副本,但新版JVM為了防止記憶體洩漏而取消了這一優化(大型原始字串已不再使用,但仍被子字串引用而無法釋放空間)。我們可以手動實現一個StringView類來做同樣的優化,但substring()的呼叫次數不多,因此對實際效能沒有什麼提升。這個技術還是比較有意思的,StringView的實現在文尾的附錄提供。

動態規劃法

動態規劃法能達到和memoization差不多甚至更快的速度,因為它也避免了重複的計算,而且空間複雜度只有O(1)。
《演算法概論》和《演算法設計》這兩本演算法教材都講到,動態規劃不可能用遞迴來實現。動態規劃的思想是從初始狀態出發,一步步產生新狀態,每一個新狀態是從之前的一個或多個狀態得到的,實現起來像一個狀態機,因此不可能用遞迴來表達。對於這道題,可以持續掃描輸入串,每一步向前掃一位,用當前已掃過的字元序列來表示當前狀態,若當前狀態以某個key結尾,則把此key的對映值乘到它所對應的“上一步”的結果集上,把新的結果集儲存到快取。這個“上一步”是哪一步呢?不是簡單地回退一位,而是把剛才掃到的key從當前狀態去掉,就能回退到上一步的狀態(相當於回退了key.length位),由於每一步都把結果集儲存到快取了,因此只要查詢快取就能得到上一步的結果集。快取以回退位數為鍵,由於回退位數最多也不可能超過最長的key的length,因此快取是一個大小不超過這個數的滑動視窗,新的值進來就把最舊的值擠掉,只需常數級空間。

實現程式碼為

static List<String> calculate(String input) {
  String state = "";
  List<String> results = new ArrayList<>();
  Cache cache = new Cache();

  for (int i = 0; i < input.length(); i++) {
    state += input.charAt(i);
    List<String> newResults = new ArrayList<>();

    for (String key : mappings.keySet()) {
      if (state.endsWith(key)) {
        List<String> prevResult = cache.get(key.length());
        List<String> values = mappings.get(key);
        if (prevResult == null) {
          newResults.addAll(values);
        } else {
          newResults.addAll(product(prevResult, values));
        }
      }
    }

    results = newResults;
    cache.put(results);
  }

  return results;
}

static List<String> product(List<String> lefts, List<String> rights) {
  List<String> results = new ArrayList<>();
  for (String left : lefts) {
    for (String right : rights) {
      results.add(left + right);
    }
  }
  return results;
}

// Sliding window
private static class Cache {
  private int maxLength = mappings.keySet().stream().map(String::length).max(Integer::compareTo).get();
  private LinkedList<List<String>> queue = new LinkedList<>();

  List<String> get(int lookBackLength) {
    if (queue.size() < lookBackLength) {
      return null;
    }
    return queue.get(lookBackLength - 1);
  }

  void put(List<String> solutions) {
    if (queue.size() == maxLength) {
      queue.removeLast();
    }
    queue.offerFirst(solutions);
  }
}

附錄

StringView的實現程式碼

class StringView {
  private final String string;
  private final int offset;

  StringView(String string, int offset) {
    if (offset > string.length()) {
      throw new IllegalArgumentException("offset should be within string length");
    }
    this.string = string;
    this.offset = offset;
  }

  StringView subview(int additionalOffset) {
    return new StringView(string, offset + additionalOffset);
  }

  boolean isEmpty() {
    return offset == string.length();
  }

  boolean startsWith(String key) {
    int keyLength = key.length();
    if (string.length() < offset + keyLength) {
      return false;
    }

    for (int i = 0; i < keyLength; i++) {
      if (key.charAt(i) != string.charAt(i + offset)) {
        return false;
      }
    }

    return true;
  }
}

相關文章