RangeBST

扱う座標数を\(N\)とすると,RangeBSTの各操作と時間計算量は以下.

  • 構築: 処理なし
  • 指定された1次元座標に値(モノイドを成す集合の要素)の設置: \(O(\log N)\)
  • 指定された1次元座標の値の更新: \(O(\log N)\)
  • 任意の座標区間の"和"(モノイドにおける2項演算) : \(O(\log N)\)

※このRangeBSTという名前は,筆者が勝手につけたもので,このデータ構造には特別名前はなく,中身はBST,つまり平衡2分探索木である.平衡2分探索木に機能を追加することで能力を拡張する例は多々あり,これはその1つの例である.

セグメント木と同じ?と思った方もいるかもしれないが,実は少し違う.セグメント木は,配列に対する区間クエリに対応したもの,つまり,言い換えると座標に制限(0~配列のサイズ)がある.RangeBSTは,データを,座標の集合としてもつため,座標の値自体に制限がなく,座圧+セグ木で解く必要がある問題も,RangeBSTなら座圧せずに解くことができる(データ構造内部でも座圧操作をしない!!)ため,実装が楽になり,実行速度も速くなることがある.

説明

配列に対する区間クエリではないため,最初にサイズを与えたりはしない.使用するメモリ量は,追加した座標の数に比例して動的に増える(座標に重複がある場合は増えない).

例えば,座標10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150にそれぞれ何かしらの値を設置したとする.このとき,平衡2分探索木でこの集合を持つ形は複数あるが,仮に1番綺麗な形として下のようになったとしよう.

%%{init: {"flowchart" : { "curve" : "basis" } } }%%
graph TD
80((80)) --- 40((40))
80 --- 120((120))
40 --- 20((20))
40 --- 60((60))
120 --- 100((100))
120 --- 140((140))
20 --- 10((10))
20 --- 30((30))
60 --- 50((50))
60 --- 70((70))
100 --- 90((90))
100 --- 110((110))
140 --- 130((130))
140 --- 150((150))

ではここで,座標\(25 \sim 115\)の累積"和"を求めよと言われたら,どこのノードの値の"和"を求めればいいかというと,下の赤い部分である.

%%{init: {"flowchart" : { "curve" : "basis" } } }%%
graph TD
80((80)):::p --- 40((40)):::p
80 --- 120((120))
40 --- 20((20))
40 --- 60((60)):::p
120 --- 100((100)):::p
120 --- 140((140))
20 --- 10((10))
20 --- 30((30)):::p
60 --- 50((50)):::p
60 --- 70((70)):::p
100 --- 90((90)):::p
100 --- 110((110)):::p
140 --- 130((130))
140 --- 150((150))

classDef p stroke-width:4px,stroke:orangered;

もちろんそのままやっていてはノード数分の計算量がかかるが,上の赤く塗られたノードを見ると,平衡2分探索木のおかげである程度まとまった位置に求めたいノードが集まっていることがわかる.

じゃあそれぞれのノードが自分を根とする部分木の"和"を持っていたらどうなるかと考えると,見るノードは下の色がついたノードだけでよくなることがわかる.赤いノードはそのノードの値が必要であることを示しており,青いノードはそのノードを根とする部分木の累積"和"が必要であることを示している.

%%{init: {"flowchart" : { "curve" : "basis" } } }%%
graph TD
80((80)):::v --- 40((40)):::v
80 --- 120((120))
40 --- 20((20))
40 --- 60((60)):::p
120 --- 100((100)):::p
120 --- 140((140))
20 --- 10((10))
20 --- 30((30)):::p
60 --- 50((50))
60 --- 70((70))
100 --- 90((90))
100 --- 110((110))
140 --- 130((130))
140 --- 150((150))

classDef v stroke-width:4px,stroke:orangered;
classDef p stroke-width:4px,stroke:slateblue;

とても少なくなった.実は,この工夫だけで,任意区間の累積"和"を求めるのに必要なノード数が\(O(\log N)\)個に抑えられるのである.その理由を説明する.

これは言葉で説明するより図と疑似コードを見た方がわかりやすい.

よりちゃんとした書き方として,累積"和"を取得したい区間を \([25, 115)\)と書く.\([\)は閉区間を意味し,\()\)は開区間を意味する.まず,\(25\)と\(115\)を木上で2分探索(lower_bound)すると,下の緑のノードにたどり着く.また,この2つのノードのLCAをピンク色で示す.

%%{init: {"flowchart" : { "curve" : "basis" } } }%%
graph TD
80((80)):::lca --- 40((40))
80 --- 120((120)):::lb
40 --- 20((20))
40 --- 60((60))
120 --- 100((100))
120 --- 140((140))
20 --- 10((10))
20 --- 30((30)):::lb
60 --- 50((50))
60 --- 70((70))
100 --- 90((90))
100 --- 110((110))
140 --- 130((130))
140 --- 150((150))

classDef v stroke-width:4px,stroke:orangered;
classDef p stroke-width:4px,stroke:slateblue;
classDef lb stroke-width:4px,stroke:green;
classDef lca stroke-width:4px,stroke:deeppink;

左のノードは,LCAまで上りながら,右の部分木がある場合だけその累積"和"を累積すればよく,右のノードも同様に,LCAにまで上りながら,左の部分木がある場合だけその累積"和"を累積すればよい.わかりづらいと思うので,先ほどの図と重ねてみるとわかりやすい.  

%%{init: {"flowchart" : { "curve" : "basis" } } }%%
graph TD
80((80)):::lca --- 40((40)):::v
80 --- 120((120)):::lb
40 --- 20((20))
40 --- 60((60)):::p
120 --- 100((100)):::p
120 --- 140((140))
20 --- 10((10))
20 --- 30((30)):::lb
60 --- 50((50))
60 --- 70((70))
100 --- 90((90))
100 --- 110((110))
140 --- 130((130))
140 --- 150((150))

classDef v stroke-width:4px,stroke:orangered;
classDef p stroke-width:4px,stroke:slateblue;
classDef lb stroke-width:4px,stroke:green;
classDef lca stroke-width:4px,stroke:deeppink;

疑似コードを示すとこのようになる.

fn prod(xl: i64, xr: i64) -> S {
  let l: *Node = lower_bound(xl);
  let r: *Node = lower_bound(xr);
  let lca: *Node = get_lca(l, r);
  let lprod: S = e();
  { // 左からLCAまで上りながら右側の部分木を累積
    for (bool f = true; l != lca;) {
      if (f) lprod = op(lprod, l.r.prod_subtree);
      f = l.is_left_child();
      if (f && l.p != lca) lprod = op(lprod, l.p.v);
      l = l->p;
    }
  }
  let rprod: S = e();
  { // 右からLCAまで上りながら左側の部分木を累積
    for (bool f = true; r != lca;) {
      if (f) rprod = op(r.l.prod_subtree, rprod);
      f = r.is_right_child();
      if (f && r.p != lca) rprod = op(r.p.v, rprod);
      r = r->p;
    }
  }
  return op(op(lprod, lca.v), rprod);
}

よって,区間クエリの処理は,まず木上でlower_boundし,その次に2つのノードからLCAまで上りながら\(O(1)\)の演算を各ステップで行うだけである.したがって,平衡2分探索木の高さが\(\log N\)であることから,計算量は\(O(\log N)\)となる.

最後に,更新クエリや挿入クエリについて説明する.挿入/更新を\(O(\log N)\)で行えることは平衡2分探索木の章を確認して欲しい.挿入/更新後,ノードの累積"和"の値に影響が出るのは該当ノードの祖先ノード高々\(O(\log N)\)個であるため,それらの値を更新していけばよい.計算量は\(O(\log N) + O(\log N) = O(\log N)\)となる.

実装におけるその他の注意事項

モノイドを扱っているため,例のごとく演算の向きを間違えないようにする.

以上!

コード

平衡2分探索木にSplay木を用いた実装.

template <class S, S (*op)(S, S), S (*e)()> struct Node {
  Node<S, op, e> *l, *r, *p;
  i64 pt;
  S v, prod_st;
  explicit Node(i64 pt_, S v_)
    : l(nullptr), r(nullptr), p(nullptr), pt(pt_), v(v_), prod_st(v_) {}
  int state() {
    if (p && p->l == this) return -1;
    if (p && p->r == this) return 1;
    return 0;
  }
  S get_lprod() {
    if (!l) return e();
    return l->prod_st;
  }
  S get_rprod() {
    if (!r) return e();
    return r->prod_st;
  }
  void update() {
    prod_st = op(op(get_lprod(), v), get_rprod());
  }
  void rotate() {
    Node<S, op, e> *par = p;
    Node<S, op, e> *mid;
    if (p->l == this) {
      mid = r; r = par;
      par->l = mid;
    } else {
      mid = l; l = par;
      par->r = mid;
    }
    if (mid) mid->p = par;
    p = par->p; par->p = this;
    if (p && p->l == par) p->l = this;
    if (p && p->r == par) p->r = this;
    par->update(); update();
  }
  void splay() {
    while(state()) {
      int st = state() * p->state();
      if (st == 0) {
        rotate();
      } else if (st == 1) {
        p->rotate();
        rotate();
      } else {
        rotate();
        rotate();
      }
    }
  }
};

template <class S, S (*op)(S, S), S (*e)()> struct RangeBST {
private:
  using NC = Node<S, op, e>;
  NC *root, *min_, *max_;
  void splay(NC *node) { node->splay(), root = node; }
  NC* bound(i64 x, bool lower) {
    NC *valid = nullptr, *left = root, *right = nullptr;
    while (left) {
      valid = left;
      if ((lower && !(x > left->pt)) || (!lower && (x < left->pt))) {
        right = left;
        left = left->l;
      } else left = left->r;
    }
    if (!right && valid) splay(valid);
    return right;
  }
  void set(i64 x, S val, bool add) {
    NC *nn = new NC(x, val);
    // if no nodes in tree
    if (!root) {
      min_ = nn, max_ = nn, root = nn; return;
    } if (min_->pt > x) { // if x become min key in tree
      min_->l = nn, nn->p = min_, min_ = nn;
      splay(nn); return;
    } if (max_->pt < x) { // if x become max key in tree
      max_->r = nn, nn->p = max_, max_ = nn;
      splay(nn); return;
    }
    NC *node = bound(x, true); // assert node is not null
    if (node->pt == x) { // if tree already has key x
      if (add) node->v = op(node->v, val);
      else node->v = val;
      node->update(); splay(node); delete nn; return;
    }
    // now node is first node whose key is larger than x
    nn->l = node->l; node->l = nn;
    nn->p = node; if (nn->l) nn->l->p = nn;
    nn->update(); splay(nn);
  }
public:
  RangeBST() : root(nullptr), min_(nullptr), max_(nullptr) {}
  NC* lower_bound(i64 x) {
    NC *ret = bound(x, true);
    if (ret) splay(ret);
    return ret;
  }
  NC* upper_bound(i64 x) {
    NC *ret = bound(x, false);
    if (ret) splay(ret);
    return ret;
  }
  S get(i64 x) {
    NC *ret = lower_bound(x);
    if (!ret || ret->pt != x) return e();
    return ret->v;
  }
  void set(i64 x, S val) { set(x, val, false); }
  void add(i64 x, S val) { set(x, val, true); }
  S prod(i64 xl, i64 xr) {
    assert(xl <= xr);
    if (!root || xl > max_->pt || xr <= min_->pt) return e();
    if (xl <= min_->pt && xr > max_->pt) return root->prod_st;
    if (xl <= min_->pt) return lower_bound(xr)->get_lprod();
    lower_bound(xl); // now xl is root
    if (xr > max_->pt) return op(root->v, root->get_rprod());
    NC *right = bound(xr, true);
    NC *tmp = right;
    S ret = e();
    for (bool f = true; tmp != root;) {
      if (f) ret = op(tmp->get_lprod(), ret);
      f = tmp->state() == 1;
      if (f) ret = op(tmp->p->v, ret);
      tmp = tmp->p;
    }
    if (right) splay(right);
    return ret;
  }
};