Don't Repeat Yourself

Don't Repeat Yourself (DRY) is a principle of software development aimed at reducing repetition of all kinds. -- wikipedia

【競プロノート】bit全探索を学ぶ

全探索にもいろいろ種類があるそうで、今日は bit 全探索という方法を勉強したのでそれについてまとめてみます。まとめないと忘れる。

解きたい問題

この記事に載っている部分和の問題を解きます。説明のために問題文を拝借します🙏🏻

qiita.com

n 個の正の整数 a[0],a[1],…,a[n−1] と正の整数 A が与えられる。これらの整数から何個かの整数を選んで総和が A になるようにすることが可能か判定せよ。可能ならば "YES" と出力し、不可能ならば "NO" と出力せよ。
【制約】 ・1≤n≤100 ・1≤a[i]≤1000 ・1≤A≤10000
【数値例】 1)  n=3  a=(7,5,3)  A=10  答え: YES (7 と 3 を選べばよいです)

2)  n = 2  a=(9,7)  A=6  答え: NO

解き方のアイディア

普通に解こうとすると、まず与えられた文字列を split して、「+」を入れられる位置を全部網羅して…といった解法になると思います。が、それは少しスマートではないように思えます。ここで出てくるのが、bit 全探索という解法です。ただし、bit 全探索は少々コストの大きい計算で、たとえば n = 100 くらいになると計算時間は膨大です。*1

bit 全探索は何をする方法?

今回の実装では、指定した数 n - 1 の部分集合をすべて探索できます (0-indexed です)。具体的には、n = 2 だった場合、{0, 1} についての部分集合を全探索します。つまり、({φ},) {0}, {1}, {0, 1} を全探索してリストアップしてくれるというすぐれものです。

ちなみに、集合 A のすべての部分集合からなる集合を A の冪集合と呼びます。集合 A = {0, 1} があった際の冪集合は、先ほどの結果を利用して P(A) = {φ, {0}, {1}, {0, 1}} と記述できます*2。bit 探索ではこの冪集合を求めていると考えてよいと思います。

通常の全探索であれば計算量は O(4n) になりますが、この bit 全探索を用いれば、計算量は O(3n) に落ちます。

#include <iostream>
#include <vector>
using namespace std;

int main() {
    int n = 2;

    for (int bit = 0; bit < (1 << n); bit++) {
        vector<int> S;
        for (int i = 0; i < n; i++) {
            if (bit & (1 << i)) {
                S.push_back(i);
            }
        }

        cout << bit << ": {";
        for (int i = 0; i < (int)S.size(); i++) {
            cout << S[i] << " ";
        }
        cout << "}" << endl;
    }
}

出力結果は、

$ ./bit_search_prac
0: {}
1: {0 }
2: {1 }
3: {0 1 }

と出てきます。また、n = 5 とすると、集合 A = {0, 1, 2, 3, 4} となります。したがってその冪集合 P(A) は、下記のように出力されます。

$ ./bit_search_prac
0: {}
1: {0 }
2: {1 }
3: {0 1 }
4: {2 }
5: {0 2 }
6: {1 2 }
7: {0 1 2 }
8: {3 }
9: {0 3 }
10: {1 3 }
11: {0 1 3 }
12: {2 3 }
13: {0 2 3 }
14: {1 2 3 }
15: {0 1 2 3 }
16: {4 }
17: {0 4 }
18: {1 4 }
19: {0 1 4 }
20: {2 4 }
21: {0 2 4 }
22: {1 2 4 }
23: {0 1 2 4 }
24: {3 4 }
25: {0 3 4 }
26: {1 3 4 }
27: {0 1 3 4 }
28: {2 3 4 }
29: {0 2 3 4 }
30: {1 2 3 4 }
31: {0 1 2 3 4 }

といったように、すべてのパターンを網羅して出してくれます。初めてみたときちょっと感動しました。1行1行の結果を配列にしてもたせておいて、すべての項を足し上げれば、部分和問題が解けそうな気がしてきますね🙂

実際に使ってみる

まず解いてみた

実装そのものは決して難しくはありません。2重ループを回すだけです。最初のループで部分集合の全探索をし、内側のループで bit の表す集合を求めていくだけです。

#include <iostream>
#include <vector>

using namespace std;

int n;
int a[25];
int A;

int main() {
    cin >> n;
    for (int i = 0; i < n; i++) cin >> a[i];
    cin >> A;

    bool existence = false;
    for (int bit = 0; bit < (1 << n); bit++) {
        int sum = 0;
        for (int i = 0; i < n; i++) {
            if (bit & (1 << i)) {
                sum += a[i];
            }
        }

        if (sum == A) {
            existence = true;
        }
    }

    if (existence) cout << "Yes" << endl; else cout << "No" << endl;
}

この結果、

$ ./total_sum
3
7 5 3
10
Yes
$ ./total_sum
2
9 7 
6
No

記事と同じような期待通りの結果が得られます。

しかし、中でイマイチどういった操作を行っているのか腑に落ちませんでした。そもそも今回のように bit 演算を利用するとなぜ、部分集合を出力できるのでしょうか?原理を知る旅に出ましょう。

原理がわからない!

わからない点を整理します。

  • 最初のループの条件がいきなり黒魔術: for (int bit = 0; bit < (1 << n); bit++) がわからない。具体的には、(1 << n) で何が得られているのか?
  • if (bit & (1 << i)): 条件分岐に使用しているので、たとえば真が存在するなら bit & (1 << i) が 1 になるタイミングがあるのだろうけど、これはどういう原理で起きているんだろう?

ここから Playground が便利な関係で Rust で説明します(笑)。

シフト演算

(1 << n) の正体はシフト演算です。

まずシフト演算の一般的な話ですが、2進数において、ビットの位置をずらす演算をシフト演算と呼びます。左にずらずと左シフト、右にずらすと右シフトです。ちなみにコンピュータは内部はすべて2進数で表現されているので、1ビット左にずらすと任意の数値の2倍を得られ、1ビット右にずらすと任意の数値の1/2倍を得られます。

次に 1 << n をすると何が得られるかという話ですが、これは 2n を得る計算をしています。1 から n ビット左にずらすと得られる値を取得しています。Rust のコードですが、実際に動かしてみてみましょう。

fn main() {
    let bit1 = 1 << 1;
    let bit2 = 1 << 2;
    let bit3 = 1 << 3;
    let bit4 = 1 << 4;
    let bit5 = 1 << 5;
    
    println!("{}", bit1);
    println!("{}", bit2);
    println!("{}", bit3);
    println!("{}", bit4);
    println!("{}", bit5);
}

この結果は、下記のとおりとなります。

2
4
8
16
32

21, 22, ..., 25 という結果が得られているとわかります。せっかくなのでバイナリでも取得してみましょう。わかりやすさのために1も標準出力します。

fn main() {
    let bit1 = 1 << 1;
    let bit2 = 1 << 2;
    let bit3 = 1 << 3;
    let bit4 = 1 << 4;
    let bit5 = 1 << 5;
    
    println!("{:#08b}", 1);
    println!("{:#08b}", bit1);
    println!("{:#08b}", bit2);
    println!("{:#08b}", bit3);
    println!("{:#08b}", bit4);
    println!("{:#08b}", bit5);
}

結果は、

0b000001
0b000010
0b000100
0b001000
0b010000
0b100000

となり、1 がシフトした箇所に立っていっている様子がわかります。美しい🥰

ビットフラグ判定の &

bit & (1 << i) の正体はビットフラグの判定です。

ビット演算には & 演算があります。論理積を取る演算です。それぞれのビットについて両方とも1の場合は1を、それ以外の場合は0を返す演算です。ビット演算における & は、特定のビットを取り出すときに使用します。

まず「論理積を取る」について理解します。次のコードをご覧ください。

fn main() {
    let bit = 5;
    println!("{:#08b}", bit);
    println!("{:#08b}", (1 << 2));
    println!("{:#08b}", bit & (1 << 2));
}

標準出力の結果は次のようになります。

0b000101
0b000100
0b000100

&論理積の取得であったと思い出すと、

   0b000101
*) 0b000100
------------

が行われています(0b は Rust の標準出力上の話で、bit を示す接頭辞なので計算上からはスルーできます)。この掛け算を、各桁ごとに計算すると、

   0b000101
*) 0b000100
------------
   0b000100

です。1 * 1 = 1, 0 * 0 = 0, 1 * 0 = 0 です。これがまず、& 演算子の正体です。

次に、bit & (1 << i) の意味を見ていきましょう。(1 << i) については、先ほどは1 を左に i分だけシフトさせると書きましたが、 (1 << i) は i ビット目に1を立てるという意味でもあります。たとえば、1 << 3 は、3ビット目に1を立てます。つまり、bit & (1 << i) は、ビット bit に i 番目のフラグが立っているかどうかを判定しています。たとえば、bit = 3 だったときに (1 << 0) との論理積を取ると、

fn main() {
    let bit = 3;
    println!("{:#08b}", bit);
    println!("{:#08b}", (1 << 0));
    println!("{:#08b}", bit & (1 << 0));
}
0b000011 // bit = 3
0b000001 // 1 << 0
0b000001 // bit & (1 << 0)

0b000001 は、10進数では 1 を示します。つまり、0番目のビットの論理積は10進数上は 1 になります。これを if 文に入れ込めば、そのまま真偽判定に利用できますね!bit 全探索ではこれを利用していたのでした。

困ったときの標準出力

先ほどの C++ のコードにもう一度戻りましょう。上記を踏まえた上で、結局どういう動作をしているのか改めて見直してみます。状態遷移を追いたくなったら、することはひとつですね。標準出力して確かめてみましょう。コードを次のように改変してみます。

#include <iostream>
#include <vector>

using namespace std;

int n;
int a[25];
int A;

int main() {
    cin >> n;
    for (int i = 0; i < n; i++) cin >> a[i];
    cin >> A;

    bool existence = false;

    for (int bit = 0; bit < (1 << n); bit++) {
        int sum = 0;
        for (int i = 0; i < n; i++) {
            if (bit & (1 << i)) {
+                cout << "i: " << i << ", bit: " << bit  << ", (1 << i): " << (1 << i) << endl;
                sum += a[i];
            }
        }

+        cout << "sum: " << sum << endl;

        if (sum == A) {
            existence = true;
        }
    }

    if (existence) cout << "Yes" << endl; else cout << "No" << endl;
}

まず、1つ目のループで謎だった for (int bit = 0; bit < (1 << n); bit++) についてですが、たとえば n = 3 を与えると、bit < 8 が出来上がります。

bit全探索の紹介時に、 n = 100 くらいにすると計算量が、という話を書きましたが 2100 = 1267650600228229401496703205376 の計算をするので相当大きなループになってしまいます。仮に n = 16 を入れたとしても、216 = 65536 なので、それなりのループ量になってきます。さらに、中でもう一度 n = 16 なり n = 100 のループを回すので、計算量が大変ですね。

楽ちんではあるけれど、そこまで効率のよくない探索の方法ですね。

次に謎だった if 文の条件分岐の動きを見てみましょう。今回は条件分岐の中に入った際に、i の値と bit の値と 1 << i をした結果を標準出力するようにしています。また、sum もついでに取得しています。動かしてみましょう。

$ ./total_sum
3
7 5 3
10
sum: 0 // bit = 0 の出力です
i: 0, bit: 1, (1 << i): 1
sum: 7
i: 1, bit: 2, (1 << i): 2
sum: 5
i: 0, bit: 3, (1 << i): 1
i: 1, bit: 3, (1 << i): 2
sum: 12
i: 2, bit: 4, (1 << i): 4
sum: 3
i: 0, bit: 5, (1 << i): 1
i: 2, bit: 5, (1 << i): 4
sum: 10
i: 1, bit: 6, (1 << i): 2
i: 2, bit: 6, (1 << i): 4
sum: 8
i: 0, bit: 7, (1 << i): 1
i: 1, bit: 7, (1 << i): 2
i: 2, bit: 7, (1 << i): 4
sum: 15
Yes

i の遷移を見ると、 n の部分集合 (今回は配列のインデックスに対応するので、{0, 1, 2} の部分集合) をきちんと取得できています。n = 3 だったとき、部分集合は {φ}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}, {0, 1, 2} ですね。バッチリ取得できています。あとはこれと、もともとの入力の [7, 5, 3] という配列とを対応させて足し上げ、合計値が10であるかどうかを確かめれば大丈夫です。

ビット演算の左シフトが、1インクリメントすると1つ左のフラグを1に変えていくことを利用して、部分集合を上手に取得できてしまうアルゴリズムを今回は学びました。ビット演算すごい!

参考記事

*1:ちなみにそういうときは DP を使用するようです。もしかすると、DP をそもそも使う問題なので、慣れてしまえば bit 全探索である必要はないのかもしれません。

*2:P = Power で冪の意味。