STL原始碼之rotate函式結合圖和例項分析

FreeeLinux發表於2017-01-04


今天看 STL 原始碼看到 rotate() 函式這一塊,該函式就是將 [first, middle) 的元素和 [middle, last) 的元素互換。middle 的元素會成為容器的第一個元素。如果有個數字序列 {1, 2, 3, 4, 5, 6, 7},對元素 3 做旋轉操作,會形成 {3, 4, 5, 6, 7, 1, 2}。其實這就是我們平時說的左旋轉字串,只不過泛型化了而已。它可以旋轉的內容不止字串,其他迭代器型別都可以。


三種方法的分析: 

演算法1(分組交換):(來自網友:雁過無痕)
若a長度大於b,將ab分成a0a1b,交換a0和b,得ba1a0,只需再交換a1 和a0。若a長度小於b,將ab分成ab0b1,交換a和b0,得b0ab1,只需再交換a 和b1。不斷將陣列劃分和交換,直到不能再劃分為止。分組過程與求最大公約數很相似。

程式碼如下:

emplate <class ForwardIterator, class Distance>
// Distance型別僅僅對於random iterator的實現版本有意義,但為了便於上層程式碼便於呼叫,所以使用// 了同樣的簽名。
void __rotate(ForwardIterator first, ForwardIterator middle, ForwardIterator last, Distance*, forward_iterator_tag)
{
    for (ForwardIterator i = middle; ;) {
        // iter_swap用於交換兩個iterator所指向的內容。
        // 也可以這樣寫:swap(*first, *i);
        iter_swap(first, i);
        ++first;
        ++i;
        if (first == middle) {
            // first和i同時到達末尾,元素交換結束,返回。
            if (i == last)
                return;
            // first首先到達末尾,說明A的長度小於B。
            middle = i;
        }
        // i首先到達末尾,說明A的長度大於B。
        else if (i == last)
            i = middle;
    }
}

 
演算法2 (三次反轉)
利用ba=(br)r(ar)r=(arbr)r,先分別反轉a、b,最後再對所有元素進行一次反轉。

程式碼如下:

template <class BidirectionalIterator, class Distance>
void __rotate(BidirectionalIterator first, BidirectionalIterator middle, BidirectionalIterator last, Distance*, bidirectional_iterator_tag)
{
    // 翻轉A
    reverse(first, middle);
    // 翻轉B
    reverse(middle, last);
    // 翻轉A'B'
    reverse(first, last);
}

演算法3 (使用gcd)(分析來自網友:陳覃)
__gcd是求兩個數的最大公約數,也是迴圈位移的遍數。
舉個例子來說明演算法過程,陣列123456789,把123翻轉到右邊,*first=1,*last=9,*middle=4;
要旋轉字串(123)的長度為3,字串長度為9,3和9的最大公約數為3,因此需要翻轉3遍;
第一遍從*(initial+shift)=6開始,6移到3的位置,9移到6的位置,下一個位置是ptr2 = first + (shift - (last - ptr2))=0+(3-(8-8))=3,不滿足ptr2 != initial的條件,退出迴圈,然後*ptr1 = value,即把數字3移動到數字9的位置,從而完成了3,6,9三個數字的位移,下面的2遍迴圈則分別完成2,5,8和1,4,76個數字的位移,最後得到最終結果456789123。

對於輾轉相除法更詳細的證明可以參考我以前的部落格:輾轉相除法、埃拉托色尼篩選法、牛頓迭代法證明與C++實現

整個演算法過程可用下圖表示:      


程式碼如下:

template <class RandomAccessIterator, class Distance>
void __rotate(RandomAccessIterator first, RandomAccessIterator middle, RandomAccessIterator last, Distance*, random_access_iterator_tag)
{
    // gcd是求最大公約數的函式。
    Distance n = __gcd(last - first, middle - first);

    while (n--)   //注意這裡是n--,我因為沒看見這個n--,時間浪費了半天
        // 需要執行__rotate_cycle n次。
        __rotate_cycle(first, last, first + n, middle - first, value_type(first));
}

template <class RandomAccessIterator, class Distance, class T>
void __rotate_cycle(RandomAccessIterator first, RandomAccessIterator last, RandomAccessIterator initial, Distance shift, T*)
{
    T value = *initial;
    RandomAccessIterator ptr1 = initial;
    RandomAccessIterator ptr2 = ptr1 + shift;

    while (ptr2 != initial) {
        *ptr1 = *ptr2;
        ptr1 = ptr2;
        if (last - ptr2 > shift)
            ptr2 += shift;
        else
            ptr2 = first + (shift - (last - ptr2));
    }

    *ptr1 = value;
}

template <class EuclideanRingElement>
EuclideanRingElement __gcd(EuclideanRingElement m, EuclideanRingElement n)
{
    while (n != 0) {
        EuclideanRingElement t = m % n;
        m = n;
        n = t;
    }

    return m;
}

由於前兩種比較簡單,在這裡我僅實現了第三種作為練習,一次性AC:)

#include <iostream>
#include <assert.h>

int calc_gcd(int m, int n)
{
    if(m < n)
        std::swap(m, n); 
    if(n == 0)
        return m;
    calc_gcd(n, m%n);
}

template <typename T>
void cycle_rotate(T *arr, int *first, int *last, int *initial, int rotate_num)
{
    T value = *initial;
    T *ptr1 = initial, *ptr2 = ptr1 + rotate_num;
    while(ptr2 != initial){
        *ptr1 = *ptr2;
        ptr1 = ptr2;
        if(last - ptr2 >= rotate_num)  //可以等於,因為是下標
            ptr2 += rotate_num;
        else
            ptr2 = first + (rotate_num - (last - ptr2)) - 1; //注意要減一,因為我這裡用的是下標
    }   
    *ptr1 = value;
}

template <typename T>
void rotate(T* arr, int start, int end, int rotate_num)
{
    assert(start >= 0 && rotate_num > start && rotate_num <= end+1);
  int gcd = calc_gcd(end-start+1, rotate_num);
    while(gcd--)
        cycle_rotate(arr, arr+start, arr+end, arr+start+gcd, rotate_num);
}

int main()
{
    int array[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    int len = sizeof(array) / sizeof(int);
    rotate(array, 0, len-1, 5);
    for(auto i : array)
        std::cout<<i<<' ';
    std::cout<<std::endl;

    return 0;
}

輸出:


注:我的程式碼中,輸入分別是開始下標,結束下標,旋轉個數,所以和 STL 稍有不同。

相關文章