ABC157Eで学ぶセグ木

セグ木を学ぶチャンスだと思って、セグ木を使って解けるようにまでセグ木を学んでいきたいと思います。

この問題に歯が全く立ちませんでいた。

E - Simple String Queries

何故セグ木か

$S$の$l_q$文字目から$r_q$文字目まで (両端含む) から成る部分文字列に表れる文字が何種類あるかという質問。

愚直に実装して文字列Sを左から見ていくと$\mathcal{O}(N)$になると思います。

このまま実装するとQ回のクエリに答えるため、$\mathcal{O}(NQ)$で間に合いません。

セグ木を使うと区間に対するクエリに$\mathcal{O}(logN)$で答えられるようになります。

なのでセグ木を使うことで計算量を$\mathcal{O}(Qlog{N})$に抑えられます。

凄いですね!これは学ぶしかない!

構造

セグメント木は完全二分木になっています!

葉は子を持たず、葉以外のノードは必ず、2個のの子を持ちます。

nibunki

実装の際は葉に数列を入れます。データが2の冪乗出ない場合、一番後ろの値をコピーしておきます。

木全体でノードの数はN(葉の数)+N-1(葉を除いたノードの数)で$2N+1$個になります。

例えば、[1,2,3,4,5,6]っていうデータをセグメント木のノードに乗せる際、[1,2,3,4,5,6,6,6]として載せます。

こんな感じです。

nibunki

今回はabcdbbdという文字列について考えます。

葉に載せるデータは[a,b,c,d,b,b,d,d]となります。

ここで実際にノードにデータを入れてみたいのですが、今回はビットで管理します。

例えば一つのノードに対して26bit持ち、それぞれについて入ってるかどうかを管理することで、ビットの立っている数を求めることで答えを出せます。

aのビットが立っていることを{a}と表してこのように考えてみましょう。

ついでにこの辺りからコードを少しづつ書いてみます。

ll,REPなどのマクロ使いまくるので許して。

まずは、‘a’などの小文字をbitに直すメソッドを書きます。

1
ll charToBit(char c) {return 1<<(c-'a');}

でこれを使って、7~14に代入していきましょう。

せっかくなのでクラスを作ります。

1
2
3
4
5
6
7
8
class SegmentTree {
    vector<ll> v;
public:
    SegmentTree(int sz) {
        v.resize((2*sz)-1);
        REP(i,sz) cin >> v[sz-1+i];
    }
};

これで7~14への代入は出来ました。次は1~6ですね。

最終的にはこんな感じです。

segu

ざっとこんな感じのコードを書いてノードに値を入れていきます。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
ll charToBit(char c) {return 1<<(c-'a');}

class SegmentTree {
    int _sz;
    vector<ll> v;
    ll near_pow(ll sz) {
        ll k = 1;
        while(sz>k) k*=2;
        return k;
    }

public:
    SegmentTree(int sz) {
        _sz = near_pow(sz);
        v.resize((2*_sz)-1,0);
        string s; cin >> s;
        REP(i,sz) v[_sz-1+i] = charToBit(s[i]);
        FOR(i,sz,_sz) v[_sz-1+i] = v[_sz+sz-2]; //残りを最後の要素で埋める
        // 葉以外の要素を求める。
        REPR(i,_sz-2) v[i] = v[(2*i)+1] | v[(2*i)+2];
    }
    void print() {
        REP(i,v.size()) cout << i << ": " << lltoStr(v[i]) << endl;
    }
};

void solve() {
    int N;
    cin >> N;
    SegmentTree s(N);
    s.print();

}

入力に以下を入れる。

1
2
7
abcdbbd

出力はこんな感じに。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
0: abcd
1: abcd
2: bd
3: ab
4: cd
5: b
6: d
7: a
8: b
9: c
10: d
11: b
12: b
13: d
14: d

ノードを更新する。

最下位のノードを更新して、そこから影響ある部分を順に更新していきます。段数分の計算量がかかるため、$\mathcal{O}{(logN)}$になります。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
    void update(int key,ll val) {
        // 最下位のノードにアクセス
        key += _sz-1;
        // 最下位のノードを更新
        v[key] = val;
        // 上位ノードの更新
        while(key > 0) {
            key = (key-1)/2;
            v[key] = v[(2*key)+1] | v[(2*key)+2];
        }
    }

区間に対するクエリに答える。

再帰を用いて、上から見ていきます。

  • 今見ているノードのインデックス
  • そのインデックスの対象としている範囲

を使用して、欲しい区間を一部被覆していたら区間を二分割(子にアクセス)して、再度区間を確認します。全部被覆していたらそれを使います。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
    // 区間[a,b)の要素数を求める。 [l,r)は自分の今いるノードの区間、kは自分の今いるノードのインデックス。
    ll get(int a,int b,int l = 0,int r = -1,int k = 0) {
        // 初期化処理 最初は[l,r) = [0,N)
        if (r<0) r = _sz;

        // if(a==b) return v[_sz-1+a];

        // 欲しい区間と今いるノードの区間が交わらないなら答えに影響のない値を適当に返す
        if ( l >= b || r<= a ) return 0;

        //今いるノードの区間が全て欲しい区間の中だったらそれを答えに使う
        if ( a <= l && r <=  b) return v[k];

        // 今いるノードの区間の一部が欲しい場合は探索を進める。
        // 自分の今いるノードの区間を半分に分けて探索する。
        ll vl = get(a,b,l,(l+r)/2,2*k+1);
        ll vr = get(a,b,(l+r)/2,r,2*k+2);

        return vl | vr;
    }

157Eの解答

あとはこんな感じです。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
#include <bits/stdc++.h>
#define INF 1e9
using namespace std;

#define REPR(i,n) for(int i=(n); i >= 0; --i)
#define FOR(i, m, n) for(int i = (m); i < (n); ++i)
#define REP(i, n) for(int i=0, i##_len=(n); i<i##_len; ++i)
#define ALL(a)  (a).begin(),(a).end()

template<class T>bool chmin(T &a, const T &b) { if (b<a) { a=b; return true; } return false; }
template<class T>bool chmax(T &a, const T &b) { if (a<b) { a=b; return true; } return false; }
int gcd(int a,int b){return b?gcd(b,a%b):a;}
typedef long long ll;
ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
ll charToBit(char c) {return 1<<(c-'a');}

class SegmentTree {
    int _sz;
    vector<ll> v;
    ll near_pow(ll sz) {
        ll k = 1;
        while(sz>k) k*=2;
        return k;
    }

public:
    SegmentTree(int sz) {
        _sz = near_pow(sz);
        v.resize((2*_sz)-1,0);
        string s; cin >> s;
        REP(i,sz) v[_sz-1+i] = charToBit(s[i]);
        FOR(i,sz,_sz) v[_sz-1+i] = v[_sz+sz-2]; //残りを最後の要素で埋める
        // 葉以外の要素を求める。
        REPR(i,_sz-2) v[i] = v[(2*i)+1] | v[(2*i)+2];

    }
    string lltoStr(ll k) {
        string s;
        REP(i,26) if((1<<i) & k) s.push_back((char)i+'a');
        return s;
    }
    void print() {
        REP(i,v.size()) cerr << i << ": " << lltoStr(v[i]) << endl;
    }
    void update(int key,ll val) {
        // 最下位のノードにアクセス
        key += _sz-1;
        // 最下位のノードを更新
        v[key] = val;
        // 上位ノードの更新
        while(key > 0) {
            key = (key-1)/2;
            v[key] = v[(2*key)+1] | v[(2*key)+2];
        }
    }
    // 区間[a,b)の要素数を求める。 [l,r)は自分の今いるノードの区間、kは自分の今いるノードのインデックス。
    ll get(int a,int b,int l = 0,int r = -1,int k = 0) {
        // 初期化処理 最初は[l,r) = [0,N)
        if (r<0) r = _sz;

        // if(a==b) return v[_sz-1+a];

        // 欲しい区間と今いるノードの区間が交わらないなら答えに影響のない値を適当に返す
        if ( l >= b || r<= a ) return 0;

        //今いるノードの区間が全て欲しい区間の中だったらそれを答えに使う
        if ( a <= l && r <=  b) return v[k];

        // 今いるノードの区間の一部が欲しい場合は探索を進める。
        // 自分の今いるノードの区間を半分に分けて探索する。
        ll vl = get(a,b,l,(l+r)/2,2*k+1);
        ll vr = get(a,b,(l+r)/2,r,2*k+2);

        return vl | vr;
    }

};
ll count(ll k) {
    int cnt = 0;
    REP(i,26) if((1<<i) & k) cnt++;
    return cnt;
}
void solve() {
    int N;
    cin >> N;
    SegmentTree s(N);
    int Q;
    cin >> Q;
    REP(_,Q) {
        int typ;
        cin >> typ;
        if(typ == 1) {
            int i; char c;
            cin >> i >> c;
            s.update(i-1,charToBit(c));
        }
        if (typ == 2) {
            int l,r;
            cin >> l >> r;
            cout << count(s.get(l-1,r)) << endl;
        }
    }

}

int main() {
    solve();
    return 0;
}

提出 #10740498 - AtCoder Beginner Contest 157

参考サイト

Licensed under CC BY-NC-ND 4.0
Built with Hugo
テーマ StackJimmy によって設計されています。