Don't Repeat Yourself

Don't Repeat Yourself (DRY) is a principle of software development aimed at reducing repetition of all kinds. -- wikipedia

Vec::retain の最適化

先日 rustc に上がっていた PR にこんなものがありました。Vec::retain を最適化したという PR です。他の PR の様子なども見る限り、以前からここは取り組まれていたようですが、この度ようやく修正が入り、Vec::retain が高速化しました。

github.com

実際どのくらい高速化したかと言うと、ベンチマークを見る限りではほとんど倍になっていたケースもあるくらいには速くなっていました。その他でもパフォーマンスが微妙に高くなっていたりと、だいたいのケースで高速化しているように見えました。

この PR では、

  • swap の廃止
  • truncate の廃止

の2つが行われています。truncate の方は、実装の修正の関係で新実装の設計に合わなくなったためについでに削除されたように読めましたが、swap は明確な意図をもって削除されています。この swap の削除の話が、調べてみると意外に奥が深く、読みがいがあったのでそれを紹介しようと思って記事を書いています。

なお、これは免責事項になりますが、筆者は普段高レイヤー言語を扱うことが多く、ポインタの操作やメモリ管理の定石などに詳しいわけではありません。記事には誤りがある可能性があるので、何か誤りがありましたらご連絡ください。

今回この記事では、まず Vec::retain とは何で、旧実装がどのようなアルゴリズムのもと retain を行っていたかを解説します。また、その過程で swaptruncate といった関数を実装レベルで見ていきます。その後、新実装の解説を行い、新実装でどのような最適化が行われたために速度が向上したかについて説明する予定です。

ちなみに、コードリーディングの過程のメモは Zenn のスクラップにて公開済みです。このブログの記事は、スクラップの再編成をしたものになります。

Vec::retain とは

Vec::retain とは、条件に合うベクタ内の要素のみを残す処理を行う便利関数です。たとえば、1〜4の数列を入れてあった際に、偶数のみ残すというフィルタリングを行うことができます。Iteratorfilter に近い関数ではあります。

fn main() {
    let mut vec = vec![1, 2, 3, 4];
    vec.retain(|&x| x % 2 == 0);
    assert_eq!(vec, [2, 4]);
}

フィルタと違うポイントとしては、元の配列を破壊しながら要素を残すので、ミュータブルな関数になります。シグネチャは下記のようになっています。

    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,

旧実装のアルゴリズム

旧実装の問題点の整理

この実装の問題点は Vec::swap という関数にありました。Vec::swap 関数の中では std::ptr::copystd::ptr::copy_nonoverlapping という処理が走ります。copy の方は、C の memmove に相当します。また、copy_nonoverlapping の方は memcpy に相当します。memmove はある特定の条件を満たしてしまうとパフォーマンスが落ちることがあるようです。

旧実装では、Vec::swap は削除した要素の個数回呼び出されてしまうことになります。つまり、std::ptr::copy が削除した個数回分呼ばれてしまいます。str::ptr::copy_nonoverlapping は個数×2回呼ばれてしまいます。後ほど説明しますが、新実装ではここを削ることにしたようです。

旧実装を見ていく

では、これまでの実装はどのようになっていたのでしょうか?1.50 時点のソースコードを貼り付けます。

    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        let len = self.len();
        let mut del = 0;
        {
            let v = &mut **self;

            for i in 0..len {
                if !f(&v[i]) {
                    del += 1;
                } else if del > 0 {
                    v.swap(i - del, i);
                }
            }
        }
        if del > 0 {
            self.truncate(len - del);
        }
    }

ここで単純なベクタを例にとってどのようなアルゴリズムになっているのかを見ていきます。

たとえば、[1, 2, 3, 4] というベクタがあった際に、偶数のみを残すという retain を行うものとします。retain(|x| x % 2 == 0) という関数を書くものとします。この結果残る要素は、[2, 4] となるはずです。ミュータブルなので、元の配列の状況が破壊的に変化する点に注意が必要です。

この際行われる処理をまとめると、下記のようになります。

  • i = 0 でループ
    1. 1 は 1 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1
    2. 配列自体に変化なし。
  • i = 1 でループ
    1. 2 は 2 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
    2. del = 1 より、これは満たされる。
    3. この時点で、i - del = 1 - 1 = 0 番目と、i = 1 番目の要素が swap される。
    4. 配列は [2, 1, 3, 4] となる。
  • i = 2 でループ(配列 = [2, 1, 3, 4])
    1. 3 は 3 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1 + 1 = 2
    2. 配列自体に変化なし。
  • i = 3 でループ(配列 = [2, 1, 3, 4])
    1. 4 は 4 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
    2. del = 2 より、これは満たされる。
    3. この時点で、i - del = 3 - 2 = 1 番目と、i = 3 番目の要素が swap される。
    4. 配列は [2, 4, 3, 1] となる。
  • 次のループはないので、ループを抜ける。
  • del > 0 より、不要になった分を truncate する。

swap は2回、truncate は1回呼び出されていることがわかります。swap は削除する際に呼び出されるので、削除する要素が増えるとその分、swap が呼び出される回数も増えることになります。100個削除するなら、100回 swap が呼ばれることになります。まず、この点を抑えておく必要があります。

では、次に swaptruncate がどのような処理を行っているかについて見ていきましょう。

Vec::swap の内部実装

まず、Vec::swap 関数が何をしているかについて見ます。コードが示すとおりで、指定されたインデックスの要素同士を入れ替えます。

fn main() {
    let mut v = ["a", "b", "c", "d"];
    v.swap(1, 3);
    assert!(v == ["a", "d", "c", "b"]);
}

Vec::swap の内部実装は下記のようになっています。2つのインデックスに紐づく要素の可変なポインタを取得しておき、それを std::ptr::swap に投げ込みます。処理の本体は std::ptr::swap に任されています。

    #[stable(feature = "rust1", since = "1.0.0")]
    #[inline]
    pub fn swap(&mut self, a: usize, b: usize) {
        // Can't take two mutable loans from one vector, so instead just cast
        // them to their raw pointers to do the swap.
        let pa: *mut T = &mut self[a];
        let pb: *mut T = &mut self[b];
        // SAFETY: `pa` and `pb` have been created from safe mutable references and refer
        // to elements in the slice and therefore are guaranteed to be valid and aligned.
        // Note that accessing the elements behind `a` and `b` is checked and will
        // panic when out of bounds.
        unsafe {
            ptr::swap(pa, pb);
        }
    }

std::ptr::swap の実装は下記のようになっています。MaybeUninit という、初期化されていないかもしれない型を使って tmp を用意した後、x, y の入れ替え処理をします。

#[inline]
#[stable(feature = "rust1", since = "1.0.0")]
pub unsafe fn swap<T>(x: *mut T, y: *mut T) {
    // Give ourselves some scratch space to work with.
    // We do not have to worry about drops: `MaybeUninit` does nothing when dropped.
    let mut tmp = MaybeUninit::<T>::uninit();

    // Perform the swap
    // SAFETY: the caller must guarantee that `x` and `y` are
    // valid for writes and properly aligned. `tmp` cannot be
    // overlapping either `x` or `y` because `tmp` was just allocated
    // on the stack as a separate allocated object.
    unsafe {
        copy_nonoverlapping(x, tmp.as_mut_ptr(), 1);
        copy(y, x, 1); // `x` and `y` may overlap
        copy_nonoverlapping(tmp.as_ptr(), y, 1);
    }
}

さて、ここで copy_nonoverlappingcopy という関数が出てきました。ドキュメントを読んでみると、次のような処理をすることがわかります。

  • copy_nonoverlapping: 重なりあわないメモリ領域同士をコピーする。memcpy と同じ。
  • copy: 重なる可能性があるメモリ領域同士をコピーする。memmove と同じ。重なり合わない場合は、実質的な動作は memcpy とおなじになる。

重要なケースは copy の方です。memmove は、コピー元とコピー先のメモリ領域自体に重なりがある場合、処理速度が低下することがあるそうです*1。これは重なっている領域分のコピーを1ビットずつ走らせることに由来します*2

それでは、最後に合計で何回 memcpymemmove が走っているのかを見てみましょう。先ほど書いた処理の流れの、swap を copycopy_nonoverlapping に読み替えてみます。

  • i = 0 でループ
    1. 1 は 1 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1
    2. 配列自体に変化なし。
  • i = 1 でループ
    1. 2 は 2 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
    2. del = 1 より、これは満たされる。
    3. この時点で、i - del = 1 - 1 = 0 番目と、i = 1 番目の要素が swap される。copy は1回、copy_nonoverlapping は2回走る。
    4. 配列は [2, 1, 3, 4] となる。
  • i = 2 でループ(配列 = [2, 1, 3, 4])
    1. 3 は 3 % 2 == 0 を満たさないので、del + 1 される。この時点で del = 1 + 1 = 2
    2. 配列自体に変化なし。
  • i = 3 でループ(配列 = [2, 1, 3, 4])
    1. 4 は 4 % 2 == 0 を満たすので、次の del > 0 の判定に移る。
    2. del = 2 より、これは満たされる。
    3. この時点で、i - del = 3 - 2 = 1 番目と、i = 3 番目の要素が swap される。copy は1回、copy_nonoverlapping は2回走る。
    4. 配列は [2, 4, 3, 1] となる。
  • 次のループはないので、ループを抜ける。
  • del > 0 より、不要になった分を truncate する。

つまり、copy は2回走り、copy_nonoverlapping は4回走っていることがわかります。

Vec::truncate の内部実装

truncate 関数は指定した長さの配列に現在の配列を直すというものです。[1, 2, 3, 4]の配列に対して、truncate(2) をすると、左から2要素を残して他はカットするイメージです。

内部実装は下記のようになっています。指定した長さ以降の配列の要素のスライスを一時的に確保しておき、ベクタの長さを指定の長さに削ったあと、一時的に確保したスライスのメモリ領域を解放する処理を行っているようです。

    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn truncate(&mut self, len: usize) {
        // This is safe because:
        //
        // * the slice passed to `drop_in_place` is valid; the `len > self.len`
        //   case avoids creating an invalid slice, and
        // * the `len` of the vector is shrunk before calling `drop_in_place`,
        //   such that no value will be dropped twice in case `drop_in_place`
        //   were to panic once (if it panics twice, the program aborts).
        unsafe {
            if len > self.len {
                return;
            }
            let remaining_len = self.len - len;
            let s = ptr::slice_from_raw_parts_mut(self.as_mut_ptr().add(len), remaining_len);
            self.len = len;
            ptr::drop_in_place(s);
        }
    }

truncate には drop_in_place が使用されているという点がポイントだと思います。これによって、truncate によって残らなかった領域のメモリリークを防いでいます。

新実装のアルゴリズム

新実装の狙いと方針

新実装では、先ほど説明した swap 関数内で呼び出される copycopy_nonoverlapping の回数削減が行われます。とくに劇的に削減されるのは copy の呼び出し回数です。これが、不要だった要素の個数に関係なく1回に削減されています。

これは、メモリの重なりが多かった場合というワーストケースに劇的に効くのはもちろん、特段そうしたワーストケースでない実装においても copy の回数を純粋に減らせることにつながるため、大幅な速度アップが期待できます。実際、ベンチマークでも大幅な速度アップが見込めていそうな結果が出ています。

大まかな方針としては

  • retain 判定の結果、不要となった要素を swap しながら truncate の範囲外に追いやるという手法ではなく、スワップはするものの不要だった要素は都度 drop しておくことにした。
  • swap の際には swap 関数は呼び出さず、copy_nonoverlapping を呼び出すようにした。不要な要素は drop されるので、メモリが重ならないことを保証できるためこれを使用できると思われる。
  • copy は全体の retain 判定がすべて終了した後、後述する構造体の drop のタイミングに一度だけ行う。

新実装を見ていく

新実装をまずは貼ります。

    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        let original_len = self.len();
        // Avoid double drop if the drop guard is not executed,
        // since we may make some holes during the process.
        unsafe { self.set_len(0) };

        // Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked]
        //      |<-              processed len   ->| ^- next to check
        //                  |<-  deleted cnt     ->|
        //      |<-              original_len                          ->|
        // Kept: Elements which predicate returns true on.
        // Hole: Moved or dropped element slot.
        // Unchecked: Unchecked valid elements.
        //
        // This drop guard will be invoked when predicate or `drop` of element panicked.
        // It shifts unchecked elements to cover holes and `set_len` to the correct length.
        // In cases when predicate and `drop` never panick, it will be optimized out.
        struct BackshiftOnDrop<'a, T, A: Allocator> {
            v: &'a mut Vec<T, A>,
            processed_len: usize,
            deleted_cnt: usize,
            original_len: usize,
        }

        impl<T, A: Allocator> Drop for BackshiftOnDrop<'_, T, A> {
            fn drop(&mut self) {
                if self.deleted_cnt > 0 {
                    // SAFETY: Trailing unchecked items must be valid since we never touch them.
                    unsafe {
                        ptr::copy(
                            self.v.as_ptr().add(self.processed_len),
                            self.v.as_mut_ptr().add(self.processed_len - self.deleted_cnt),
                            self.original_len - self.processed_len,
                        );
                    }
                }
                // SAFETY: After filling holes, all items are in contiguous memory.
                unsafe {
                    self.v.set_len(self.original_len - self.deleted_cnt);
                }
            }
        }

        let mut g = BackshiftOnDrop { v: self, processed_len: 0, deleted_cnt: 0, original_len };

        while g.processed_len < original_len {
            // SAFETY: Unchecked element must be valid.
            let cur = unsafe { &mut *g.v.as_mut_ptr().add(g.processed_len) };
            if !f(cur) {
                // Advance early to avoid double drop if `drop_in_place` panicked.
                g.processed_len += 1;
                g.deleted_cnt += 1;
                // SAFETY: We never touch this element again after dropped.
                unsafe { ptr::drop_in_place(cur) };
                // We already advanced the counter.
                continue;
            }
            if g.deleted_cnt > 0 {
                // SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with current element.
                // We use copy for move, and never touch this element again.
                unsafe {
                    let hole_slot = g.v.as_mut_ptr().add(g.processed_len - g.deleted_cnt);
                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
                }
            }
            g.processed_len += 1;
        }

        // All item are processed. This can be optimized to `set_len` by LLVM.
        drop(g);
    }

大幅に実装量は増えていますが、まず構造体が追加されたことと、その構造体に対する Drop トレイとの実装が追加されたために行数が増えています。この構造体は今回の新実装の中核を担っています。

大雑把なアルゴリズムは下記のようになります。

  1. BackshiftOnDrop を作る。
  2. 1で作った構造体の processed_lenoriginal_len (元の配列の長さ) をこえるまでは、ループ処理を回し続ける。
    1. retain の条件に一致しない場合は、その要素を drop しておく。
    2. 要素を削除したカウンタが0より大きければ、削除分を反映した配列の状態を move しておく。
  3. 1を drop する。drop 時に後処理として、下記2つが走る。
    1. 要素を削除したカウンタが0より大きければ、処理した分を copy する。
    2. ベクタ自身の持つサイズを現状のものに調整する。

BackshiftOnDrop は3つのフィールドを持ちます。

  • processed_len: 現在の retain の条件の判定がどこまで進められたかを記録する。
  • deleted_cnt: 不要判定され削除(= drop)された要素の個数を記録する。
  • original_len: もともとのベクタの長さを保持する。

また、コメントにもある通り、ベクタの走査中には次のようなステータスを1つ1つの要素に概念的に適用します。これらの用語は後々アルゴリズムの説明の際に用いられます。一方で、実装には現れてきません。

  • Kept: retain 判定の結果、保持されることが決定された状態の要素を指す。
  • Hole: retain 判定の結果、drop された状態の要素を指す。
  • Unchecked: まだ走査されていない状態の要素を指す。

この実装を、再度 [1, 2, 3, 4] というベクタに対して偶数を retain するという例に対して適用してみましょう。

まず、今回からポインタに対する操作に変わるため、便宜的にアドレスを決定しておきます。下記の決め事はあくまで説明のための例であり、実際の実行環境とは異なる状態になっている可能性があります。

address 0x7ffe4d8f54a0 0x7ffe4d8f54a4 0x7ffe4d8f54a8 0x7ffe4d8f54ac
index 0 1 2 3
value 1i32 2i32 3i32 4i32

これらを元にコードをなぞって見ていくと、下記のような手順で処理が進んでいくことがわかります。

  1. 最初の g を作る。processed_len = 0, deleted_cnt = 0, original_len = 4, 配列: [1, 2, 3, 4]
  2. 1回目の while (processed_len = 0 < original_len = 4), 配列 [1 (Unchecked), 2 (Unchecked), 3 (Unchecked), 4 (Unchecked)]
    1. cur を作る。cur = g.v.as_mut_ptr().add(0) = 0x7ffe4d8f54a0
    2. cur の示す先の値は1i32なので、最初の if ブロックの条件を満たす。processed_len = 1, deleted_cnt = 1 となる。また、curdrop される。配列は [(Hole), 2 (Unchecked), 3 (Unchecked), 4 (Unchecked)] になっているはず。
    3. 次のループに飛ぶ。
  3. 2回目の while (processed_len = 1 < original_len = 4, deleted_cnt = 1), 配列: [(Hole), 2 (Unchecked), 3(Unchecked), 4 (Unchecked)]
    1. cur を作る。cur = g.v.as_mut_ptr().add(1) = 0x7ffe4d8f54a4
    2. cur の示す先の値は 2i32 なので、最初の if ブロックの条件は満たさない。次。
    3. deleted_cnt = 1 より、条件を満たす。
      1. hole_slotg.v.as_mut_ptr().add(1 - 1) = g.v.as_mut_ptr().add(0) = 0x7ffe4d8f54a0(さっき drop したところ)。
      2. curhole_slotcopy_nonoverlapping する。つまり、[2 (Kept), (Hole), 3 (Unchecked), 4 (Unchecked)] となっているはず。
    4. processed_len = 2
  4. 3回目の while (processed_len = 2 < original_len = 4, deleted_cnt = 1), 配列: [2 (Kept), (Hole), 3 (Unchecked), 4 (Unchecked)]
    1. cur を作る。cur = g.v.as_mut_ptr().add(2) = 0x7ffe4d8f54a8
    2. cur の示す先の値は3i32なので、最初の if ブロックの条件を満たす。processed_len = 3, deleted_cnt = 2 となる。また、curdrop される。配列は [2 (Kept), (Hole), (Hole), 4 (Unchecked)] になっているはず。
    3. 次のループに飛ぶ。
  5. 4回目の while (processed_len = 3 < original_len = 4, deleted_cnt = 2), 配列: [2 (Kept), (Hole), (Hole), 4 (Unchecked)]
    1. cur を作る。cur = g.v.as_mut_ptr().add(3) = 0x7ffe4d8f54ac
    2. cur の示す先の値は 4i32 なので、最初の if ブロックの条件は満たさない。次。
    3. deleted_cnt = 2 より、条件を満たす。
      1. hole_slotg.v.as_mut_ptr().add(3 - 2) = g.v.as_mut_ptr().add(1) = 0x7ffe4d8f54a4
      2. curhole_slotcopy_nonoverlapping する。つまり、[2 (Kept), 4 (Kept), (Hole), (Hole)] となっているはず。
    4. processed_len = 4
  6. 5回目のループは条件を満たさずできない。
  7. drop 処理を行う。配列の状態は、[2 (Kept), 4 (Kept), (Hole), (Hole)]
    1. processed_len = 4, deleted_cnt = 2
    2. 終わった時点でのメモリの状態は、0x7ffe4d8f54a0 = 2, 0x7ffe4d8f54a4 =4, 0x7ffe4d8f54a8 = Hole, 0x7ffe4d8f54ac = Hole
    3. copy(g.v.as_ptr().add(4) = 0x7ffe4d8f54b0, g.v.as_mut_ptr().add(2) = 0x7ffe4d8f54a8, 0) が実行される。ただし、count に 0 が入っているので、実質何もしない*3
    4. Vec の len は2がセットされる。[2, 4] が len に含まれることになった。すでに drop 済みなので、不要と判断された要素のメモリ領域がメモリリークすることはない。

重要なポイントは、copy の回数が減っている点です。最後の 7-2 でしか copy は走らなくなっています。今回の場合は、不要だった要素数が2個しかないので実質 1/2 ですが、仮に不要だった要素が100個だった場合であっても、1回しかコピーは走りません。ということは、コピーの回数は 1/100 ということになります。これは劇的な改善だと思われます。

また、copy_nonoverlapping も2回に減っています。

したがって、コピー操作全般の回数が、旧実装(6回)と比較すると半分(新実装では3回)に減っていることがわかります。不要な要素の多さによっては減る回数がさらに増える可能性があることもわかりました。

まとめ

  • Vec::retain の実装が最適化された。
  • swap を行いながら不要になった要素を範囲外に寄せておき、最後に範囲外の要素を全部 drop するという戦略を取らなくした。
  • 不要な要素を発見したら都度 drop しておき、メモリ領域同士が重ならないことを保証しながらスワップ操作をできるようにした。こうすることで、memcpy を安心して走らせることができるし、memmove の速度低下の可能性も下げられるようになった。

Rust の Vec は実はこうしたポインタ操作の宝庫です。筆者のように、普段高レイヤー言語を触っていて参照もポインタもほぼ馴染みがない、というエンジニアにとって、Rust の Vec は絶好の教材だと思います。他の実装も読んでみて、どういった操作をしているのか知りたくなってきました。

*1:これは Rust でもそうなのかをベンチを取ってみる必要がありそうです。それは後日で。

*2:記事を参考にしただけなので、最近の実装がどうなのかまでは調べていません: https://fd0.hatenablog.jp/entry/20071222/p1

*3:例が悪く、processed_len と deleted_cnt が非対称になる例を見せるべきだった可能性はある。