Segment Tree
(data_structure/segtree/segtree.hpp)
- View this file on GitHub
- Last update: 2025-10-01 15:41:05+09:00
- Include:
#include "data_structure/segtree/segtree.hpp"
Overview
A segment tree is a versatile data structure that allows for efficient range queries on an array. It can compute the result of any associative operation (defined by a monoid) over a given range [l, r) in logarithmic time. It also supports updating individual elements.
This implementation is generic and designed to be used with the monoid structures available in the library.
Methods
-
`segment_tree()`: Constructs an empty segment tree.
Time complexity: $O(1)$.
-
`explicit segment_tree(int n)`: Constructs a segment tree of size `n`, initialized with the identity element of the monoid.
Time complexity: $O(N)$.
-
`explicit segment_tree(const std::vector<T>& v)`: Constructs a segment tree from an initial vector `v`.
Time complexity: $O(N)$.
-
`void set(int p, T x)`: Sets the value at position `p` (0-indexed) to `x`.
Time complexity: $O(\log N)$.
-
`T get(int p) const`: Returns the value at position `p` (0-indexed).
Time complexity: $O(1)$.
-
`T prod(int l, int r) const`: Calculates the result of the monoid operation on the range `[l, r)` (0-indexed, half-open interval).
Time complexity: $O(\log N)$.
-
`T all_prod() const`: Calculates the result of the monoid operation on the entire range `[0, n)`.
Time complexity: $O(1)$.
-
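`int max_right(int l, auto f) const`: Finds the largest `r` such that `l <= r <= n` and the predicate `f(prod(l, r))` is true. `f` must be a function that takes a monoid value and returns a boolean. For the result to be meaningful, `f` must be true for the identity element and monotone (once it becomes false for a range, it stays false for every longer range starting at `l`).
Time complexity: $O(\log N)$.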
-
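`int min_left(int r, auto f) const`: Finds the smallest `l` such that `0 <= l <= r` and the predicate `f(prod(l, r))` is true. `f` must be a function that takes a monoid value and returns a boolean. For the result to be meaningful, `f` must be true for the identity element and monotone (once it becomes false for a range, it stays false for every longer range ending at `r`).
Time complexity: $O(\log N)$.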
Depends on
Verified with
Code
#ifndef M1UNE_SEGTREE_HPP
#define M1UNE_SEGTREE_HPP 1
#include <algorithm>
#include <functional>
#include <type_traits>
#include <vector>
#include "monoid/monoid.hpp"
#include "utilities/bit_ceil.hpp"
namespace m1une {
// Segment tree over the monoid `M`.
//
// The values live in the leaves of a complete binary tree stored in heap
// layout; every internal node caches `M::op` of its two children, so a range
// product can be assembled from a handful of cached nodes.
template <Monoid M>
struct segment_tree {
    using T = typename M::value_type;

  private:
    int _n;                // number of elements supplied at construction
    int _size;             // leaf count, bit_ceil(_n); element i is stored at _data[_size + i]
    std::vector<T> _data;  // 2 * _size slots; _data[1] is the root, node k has children 2k and 2k + 1

    // Recomputes the cached product of node k from its two children.
    void update(int k) {
        const int left = 2 * k;
        _data[k] = M::op(_data[left], _data[left + 1]);
    }

  public:
    // Builds an empty tree (delegates to the size constructor with n = 0).
    segment_tree() : segment_tree(0) {}

    // Builds a tree of n elements, all equal to M::id().
    explicit segment_tree(int n) : segment_tree(std::vector<T>(n, M::id())) {}

    // Builds a tree holding a copy of v. Leaves past v.size() keep M::id().
    explicit segment_tree(const std::vector<T>& v)
        : _n(static_cast<int>(v.size())),
          _size(static_cast<int>(bit_ceil(static_cast<unsigned int>(_n)))),
          _data(2 * _size, M::id()) {
        std::copy(v.begin(), v.end(), _data.begin() + _size);
        // Children have larger indices than their parent, so a descending
        // sweep always sees finished children.
        for (int node = _size - 1; node > 0; --node) {
            update(node);
        }
    }

    // Overwrites element p (0-indexed) with x and refreshes every ancestor.
    void set(int p, T x) {
        int node = p + _size;
        _data[node] = x;
        for (node >>= 1; node >= 1; node >>= 1) {
            update(node);
        }
    }

    // Returns a copy of element p (0-indexed).
    T get(int p) const {
        return _data[_size + p];
    }

    // Returns the product of the half-open range [l, r).
    // Two accumulators are kept so operands are always combined in index
    // order: `front` grows to the right, `back` grows to the left.
    T prod(int l, int r) const {
        T front = M::id();
        T back = M::id();
        int lo = l + _size;
        int hi = r + _size;
        for (; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) front = M::op(front, _data[lo++]);
            if (hi & 1) back = M::op(_data[--hi], back);
        }
        return M::op(front, back);
    }

    // Returns the product of all elements, i.e. the value cached at the root.
    T all_prod() const {
        return _data[1];
    }

    // Tree-descent search to the right of l: returns an r in [l, _n] found by
    // extending the range [l, r) while f keeps accepting the product.
    // NOTE(review): the result is only the "largest valid r" under the usual
    // contract that f(M::id()) is true and f is monotone; the code does not
    // check either condition.
    int max_right(int l, auto f) const {
        static_assert(std::is_convertible_v<std::invoke_result_t<decltype(f), T>, bool>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (l == _n) return _n;
        int node = l + _size;
        T acc = M::id();
        for (;;) {
            // Move up while the node is a left child, to take the widest
            // block that still starts at the current position.
            while ((node & 1) == 0) node >>= 1;
            if (!f(M::op(acc, _data[node]))) {
                // The boundary is inside this block: walk down, keeping the
                // left half whenever f still accepts it.
                while (node < _size) {
                    node *= 2;
                    if (f(M::op(acc, _data[node]))) {
                        acc = M::op(acc, _data[node]);
                        ++node;
                    }
                }
                return node - _size;
            }
            acc = M::op(acc, _data[node]);
            ++node;
            // A power-of-two index means the scan walked off the right edge.
            if ((node & -node) == node) break;
        }
        return _n;
    }

    // Mirror image of max_right: returns an l in [0, r] found by extending the
    // range [l, r) to the left while f keeps accepting the product.
    // NOTE(review): as with max_right, the usual contract (f(M::id()) true,
    // f monotone) is assumed, not checked.
    int min_left(int r, auto f) const {
        static_assert(std::is_convertible_v<std::invoke_result_t<decltype(f), T>, bool>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (r == 0) return 0;
        int node = r + _size;
        T acc = M::id();
        for (;;) {
            --node;
            // Move up while the node is a right child (never past the root).
            while (node > 1 && (node & 1)) node >>= 1;
            if (!f(M::op(_data[node], acc))) {
                // The boundary is inside this block: walk down, keeping the
                // right half whenever f still accepts it.
                while (node < _size) {
                    node = 2 * node + 1;
                    if (f(M::op(_data[node], acc))) {
                        acc = M::op(_data[node], acc);
                        --node;
                    }
                }
                return node + 1 - _size;
            }
            acc = M::op(_data[node], acc);
            // A power-of-two index means the scan reached the left edge.
            if ((node & -node) == node) break;
        }
        return 0;
    }
};
} // namespace m1une
#endif // M1UNE_SEGTREE_HPP
#line 1 "data_structure/segtree/segtree.hpp"
#include <algorithm>
#include <functional>
#include <type_traits>
#include <vector>
#line 1 "monoid/monoid.hpp"
#include <concepts>
#line 7 "monoid/monoid.hpp"
namespace m1une {
template <typename T, auto operation, auto identity, bool commutative>
struct monoid {
static_assert(std::is_invocable_r_v<T, decltype(operation), T, T>, "operation must work as T(T, T)");
static_assert(std::is_invocable_r_v<T, decltype(identity)>, "identity must work as T()");
using value_type = T;
static constexpr auto op = operation;
static constexpr auto id = identity;
static constexpr bool is_commutative = commutative;
};
// Concept describing the interface the library expects from a monoid type T.
// T satisfies it when all of the following are well-formed:
//   - the nested type `T::value_type` exists;
//   - the expression `T::op(v, v)`, for an lvalue `v` of that type, has
//     exactly the type `T::value_type`;
//   - the expression `T::id()` has exactly the type `T::value_type`;
//   - `T::is_commutative` names something convertible to bool.
// Only the shape of the interface is checked; associativity and the identity
// law cannot be verified here and remain the caller's responsibility.
template <typename T>
concept Monoid = requires(typename T::value_type v) {
typename T::value_type;
{ T::op(v, v) } -> std::same_as<typename T::value_type>;
{ T::id() } -> std::same_as<typename T::value_type>;
{ T::is_commutative } -> std::convertible_to<bool>;
};
} // namespace m1une
#line 1 "utilities/bit_ceil.hpp"
namespace m1une {
// Returns the smallest power of two that is not less than n.
// Any n <= 1 (including zero and negative values) yields 1.
template <typename T>
constexpr T bit_ceil(T n) {
    T power = 1;
    // Keep doubling until the candidate reaches n; the loop body never runs
    // when n <= 1.
    for (; power < n; power <<= 1) {
    }
    return power;
}
} // namespace m1une
#line 11 "data_structure/segtree/segtree.hpp"
namespace m1une {
// Segment tree over the monoid `M`.
//
// The values live in the leaves of a complete binary tree stored in heap
// layout; every internal node caches `M::op` of its two children, so a range
// product can be assembled from a handful of cached nodes.
template <Monoid M>
struct segment_tree {
    using T = typename M::value_type;

  private:
    int _n;                // number of elements supplied at construction
    int _size;             // leaf count, bit_ceil(_n); element i is stored at _data[_size + i]
    std::vector<T> _data;  // 2 * _size slots; _data[1] is the root, node k has children 2k and 2k + 1

    // Recomputes the cached product of node k from its two children.
    void update(int k) {
        const int left = 2 * k;
        _data[k] = M::op(_data[left], _data[left + 1]);
    }

  public:
    // Builds an empty tree (delegates to the size constructor with n = 0).
    segment_tree() : segment_tree(0) {}

    // Builds a tree of n elements, all equal to M::id().
    explicit segment_tree(int n) : segment_tree(std::vector<T>(n, M::id())) {}

    // Builds a tree holding a copy of v. Leaves past v.size() keep M::id().
    explicit segment_tree(const std::vector<T>& v)
        : _n(static_cast<int>(v.size())),
          _size(static_cast<int>(bit_ceil(static_cast<unsigned int>(_n)))),
          _data(2 * _size, M::id()) {
        std::copy(v.begin(), v.end(), _data.begin() + _size);
        // Children have larger indices than their parent, so a descending
        // sweep always sees finished children.
        for (int node = _size - 1; node > 0; --node) {
            update(node);
        }
    }

    // Overwrites element p (0-indexed) with x and refreshes every ancestor.
    void set(int p, T x) {
        int node = p + _size;
        _data[node] = x;
        for (node >>= 1; node >= 1; node >>= 1) {
            update(node);
        }
    }

    // Returns a copy of element p (0-indexed).
    T get(int p) const {
        return _data[_size + p];
    }

    // Returns the product of the half-open range [l, r).
    // Two accumulators are kept so operands are always combined in index
    // order: `front` grows to the right, `back` grows to the left.
    T prod(int l, int r) const {
        T front = M::id();
        T back = M::id();
        int lo = l + _size;
        int hi = r + _size;
        for (; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) front = M::op(front, _data[lo++]);
            if (hi & 1) back = M::op(_data[--hi], back);
        }
        return M::op(front, back);
    }

    // Returns the product of all elements, i.e. the value cached at the root.
    T all_prod() const {
        return _data[1];
    }

    // Tree-descent search to the right of l: returns an r in [l, _n] found by
    // extending the range [l, r) while f keeps accepting the product.
    // NOTE(review): the result is only the "largest valid r" under the usual
    // contract that f(M::id()) is true and f is monotone; the code does not
    // check either condition.
    int max_right(int l, auto f) const {
        static_assert(std::is_convertible_v<std::invoke_result_t<decltype(f), T>, bool>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (l == _n) return _n;
        int node = l + _size;
        T acc = M::id();
        for (;;) {
            // Move up while the node is a left child, to take the widest
            // block that still starts at the current position.
            while ((node & 1) == 0) node >>= 1;
            if (!f(M::op(acc, _data[node]))) {
                // The boundary is inside this block: walk down, keeping the
                // left half whenever f still accepts it.
                while (node < _size) {
                    node *= 2;
                    if (f(M::op(acc, _data[node]))) {
                        acc = M::op(acc, _data[node]);
                        ++node;
                    }
                }
                return node - _size;
            }
            acc = M::op(acc, _data[node]);
            ++node;
            // A power-of-two index means the scan walked off the right edge.
            if ((node & -node) == node) break;
        }
        return _n;
    }

    // Mirror image of max_right: returns an l in [0, r] found by extending the
    // range [l, r) to the left while f keeps accepting the product.
    // NOTE(review): as with max_right, the usual contract (f(M::id()) true,
    // f monotone) is assumed, not checked.
    int min_left(int r, auto f) const {
        static_assert(std::is_convertible_v<std::invoke_result_t<decltype(f), T>, bool>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (r == 0) return 0;
        int node = r + _size;
        T acc = M::id();
        for (;;) {
            --node;
            // Move up while the node is a right child (never past the root).
            while (node > 1 && (node & 1)) node >>= 1;
            if (!f(M::op(_data[node], acc))) {
                // The boundary is inside this block: walk down, keeping the
                // right half whenever f still accepts it.
                while (node < _size) {
                    node = 2 * node + 1;
                    if (f(M::op(_data[node], acc))) {
                        acc = M::op(_data[node], acc);
                        --node;
                    }
                }
                return node + 1 - _size;
            }
            acc = M::op(_data[node], acc);
            // A power-of-two index means the scan reached the left edge.
            if ((node & -node) == node) break;
        }
        return 0;
    }
};
} // namespace m1une