m1une's library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

✔ Segment Tree
(data_structure/segtree/segtree.hpp)

Overview

A segment tree is a versatile data structure for efficient range queries on an array. Given a monoid (an associative binary operation together with an identity element), it computes the product of the elements in any half-open range [l, r) in O(log n) time. It also updates a single element in O(log n) time.

This implementation is generic: it accepts any type satisfying the library's `Monoid` concept, such as the monoid structures provided in `monoid/monoid.hpp`.

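Methods

- `segment_tree()`, `segment_tree(int n)`, `segment_tree(const std::vector<T>& v)`: build an empty tree, a tree of `n` identity elements, or a tree initialised from `v`.
- `set(p, x)` / `get(p)`: assign or read the element at position `p`.
- `prod(l, r)`: product of the elements in [l, r). `all_prod()`: product of all elements.
- `max_right(l, f)` / `min_left(r, f)`: binary search over products starting at `l` (resp. ending at `r`) for the boundary where `f` turns false; `f(id())` must be true.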

Depends on

- monoid/monoid.hpp
- utilities/bit_ceil.hpp

Verified with

Code

#ifndef M1UNE_SEGTREE_HPP
#define M1UNE_SEGTREE_HPP 1

#include <algorithm>
#include <functional>
#include <type_traits>
#include <vector>

#include "monoid/monoid.hpp"
#include "utilities/bit_ceil.hpp"

namespace m1une {

// Segment tree over a monoid M.
//
// Supports point assignment and products over half-open ranges, each using
// O(log n) calls to M::op, plus binary searches (max_right / min_left) over
// products that start at l or end at r.
//
// Layout: `_data` is an implicit 1-indexed binary tree. The leaves occupy
// [_size, 2 * _size); leaf _size + i holds element i for i < _n and M::id()
// for the padding positions i >= _n. Internal node k holds
// M::op(_data[2k], _data[2k + 1]), so `_data[1]` is the product of everything.
//
// No bounds checking is done: callers must respect the index preconditions
// documented on each method, otherwise `_data` is accessed out of range.
template <Monoid M>
struct segment_tree {
    using T = typename M::value_type;

   private:
    int _n;                // number of elements visible to callers
    int _size;             // number of leaves: a power of two, >= max(_n, 1)
    std::vector<T> _data;  // 2 * _size tree nodes; index 0 is unused

    // Recompute internal node k from its two children.
    void update(int k) {
        _data[k] = M::op(_data[2 * k], _data[2 * k + 1]);
    }

   public:
    // Empty tree (n = 0).
    segment_tree() : segment_tree(0) {}

    // Tree of n elements, each initialised to M::id().
    explicit segment_tree(int n) : segment_tree(std::vector<T>(n, M::id())) {}

    // Tree whose element i is v[i]. Internal nodes are built bottom-up with
    // _size - 1 calls to M::op.
    //
    // The member initialisers rely on declaration order: _n first, then
    // _size (computed from _n), then _data (sized from _size).
    explicit segment_tree(const std::vector<T>& v)
        : _n(static_cast<int>(v.size())),
          _size(static_cast<int>(bit_ceil(static_cast<unsigned int>(_n)))),
          _data(2 * _size, M::id()) {
        std::copy(v.begin(), v.end(), _data.begin() + _size);
        for (int i = _size - 1; i >= 1; i--) {
            update(i);
        }
    }

    // Assign x to position p, then recompute every ancestor of that leaf.
    // Requires 0 <= p < n.
    void set(int p, T x) {
        p += _size;
        _data[p] = x;
        for (int i = 1; p >> i >= 1; i++) {
            update(p >> i);
        }
    }

    // Value at position p. Requires 0 <= p < n.
    T get(int p) const {
        return _data[p + _size];
    }

    // Product a[l] * a[l + 1] * ... * a[r - 1] (left to right), or M::id()
    // for the empty range l == r. Requires 0 <= l <= r <= n.
    // Left and right partial products are accumulated separately so that a
    // non-commutative op is still applied in the correct order.
    T prod(int l, int r) const {
        T sml = M::id(), smr = M::id();
        l += _size;
        r += _size;
        while (l < r) {
            if (l & 1) sml = M::op(sml, _data[l++]);
            if (r & 1) smr = M::op(_data[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return M::op(sml, smr);
    }

    // Product of all n elements. This is the root; padding leaves hold
    // M::id() and so do not contribute.
    T all_prod() const {
        return _data[1];
    }

    // Binary search to the right of l. Returns r with l <= r <= n such that
    //   - f(prod(l, r)) is true, and
    //   - r == n or f(prod(l, r + 1)) is false.
    // If f is monotone (false for a product implies false for any longer
    // product from l), this is the largest such r.
    // Requires 0 <= l <= n and f(M::id()) == true.
    int max_right(int l, auto f) const {
        // is_invocable_r_v (unlike invoke_result_t) is itself well-formed for a
        // non-invocable f, so this message is what the user actually sees.
        static_assert(std::is_invocable_r_v<bool, decltype(f)&, T>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (l == _n) return _n;
        l += _size;
        T sm = M::id();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(M::op(sm, _data[l]))) {
                // Descend: absorb a left child whenever f still holds.
                while (l < _size) {
                    l = (2 * l);
                    if (f(M::op(sm, _data[l]))) {
                        sm = M::op(sm, _data[l]);
                        l++;
                    }
                }
                return l - _size;
            }
            sm = M::op(sm, _data[l]);
            l++;
        } while ((l & -l) != l);  // stop once l is a power of two (right edge)
        return _n;
    }

    // Binary search to the left of r. Returns l with 0 <= l <= r such that
    //   - f(prod(l, r)) is true, and
    //   - l == 0 or f(prod(l - 1, r)) is false.
    // If f is monotone (false for a product implies false for any longer
    // product ending at r), this is the smallest such l.
    // Requires 0 <= r <= n and f(M::id()) == true.
    int min_left(int r, auto f) const {
        // See max_right for why is_invocable_r_v is used here.
        static_assert(std::is_invocable_r_v<bool, decltype(f)&, T>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (r == 0) return 0;
        r += _size;
        T sm = M::id();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(M::op(_data[r], sm))) {
                // Descend: absorb a right child whenever f still holds.
                while (r < _size) {
                    r = (2 * r + 1);
                    if (f(M::op(_data[r], sm))) {
                        sm = M::op(_data[r], sm);
                        r--;
                    }
                }
                return r + 1 - _size;
            }
            sm = M::op(_data[r], sm);
        } while ((r & -r) != r);  // stop once r is a power of two (left edge)
        return 0;
    }
};

}  // namespace m1une

#endif  // M1UNE_SEGTREE_HPP
#line 1 "data_structure/segtree/segtree.hpp"



#include <algorithm>
#include <functional>
#include <type_traits>
#include <vector>

#line 1 "monoid/monoid.hpp"



#include <concepts>
#line 7 "monoid/monoid.hpp"

namespace m1une {

template <typename T, auto operation, auto identity, bool commutative>
struct monoid {
    static_assert(std::is_invocable_r_v<T, decltype(operation), T, T>, "operation must work as T(T, T)");
    static_assert(std::is_invocable_r_v<T, decltype(identity)>, "identity must work as T()");

    using value_type = T;
    static constexpr auto op = operation;
    static constexpr auto id = identity;
    static constexpr bool is_commutative = commutative;
};

// A type T models Monoid when it exposes:
//   - T::value_type, the element type;
//   - T::op(a, b) on two value_type lvalues, returning exactly value_type;
//   - T::id(), returning exactly value_type;
//   - T::is_commutative, usable as a bool.
// Associativity and the identity laws are not checked here.
template <typename T>
concept Monoid = requires(typename T::value_type lhs, typename T::value_type rhs) {
    typename T::value_type;
    { T::op(lhs, rhs) } -> std::same_as<typename T::value_type>;
    { T::id() } -> std::same_as<typename T::value_type>;
    { T::is_commutative } -> std::convertible_to<bool>;
};

}  // namespace m1une


#line 1 "utilities/bit_ceil.hpp"



namespace m1une {
// Smallest power of two that is >= n; every n <= 1 (including 0 and
// negative values) yields 1.
// Precondition: the result must be representable in T. Otherwise the
// doubling overflows: for unsigned T it wraps to 0 and the loop never ends.
template <typename T>
constexpr T bit_ceil(T n) {
    T result = 1;
    while (result < n) {
        result <<= 1;
    }
    return result;
}
}  // namespace m1une


#line 11 "data_structure/segtree/segtree.hpp"

namespace m1une {

// Segment tree over a monoid M.
//
// Supports point assignment and products over half-open ranges, each using
// O(log n) calls to M::op, plus binary searches (max_right / min_left) over
// products that start at l or end at r.
//
// Layout: `_data` is an implicit 1-indexed binary tree. The leaves occupy
// [_size, 2 * _size); leaf _size + i holds element i for i < _n and M::id()
// for the padding positions i >= _n. Internal node k holds
// M::op(_data[2k], _data[2k + 1]), so `_data[1]` is the product of everything.
//
// No bounds checking is done: callers must respect the index preconditions
// documented on each method, otherwise `_data` is accessed out of range.
template <Monoid M>
struct segment_tree {
    using T = typename M::value_type;

   private:
    int _n;                // number of elements visible to callers
    int _size;             // number of leaves: a power of two, >= max(_n, 1)
    std::vector<T> _data;  // 2 * _size tree nodes; index 0 is unused

    // Recompute internal node k from its two children.
    void update(int k) {
        _data[k] = M::op(_data[2 * k], _data[2 * k + 1]);
    }

   public:
    // Empty tree (n = 0).
    segment_tree() : segment_tree(0) {}

    // Tree of n elements, each initialised to M::id().
    explicit segment_tree(int n) : segment_tree(std::vector<T>(n, M::id())) {}

    // Tree whose element i is v[i]. Internal nodes are built bottom-up with
    // _size - 1 calls to M::op.
    //
    // The member initialisers rely on declaration order: _n first, then
    // _size (computed from _n), then _data (sized from _size).
    explicit segment_tree(const std::vector<T>& v)
        : _n(static_cast<int>(v.size())),
          _size(static_cast<int>(bit_ceil(static_cast<unsigned int>(_n)))),
          _data(2 * _size, M::id()) {
        std::copy(v.begin(), v.end(), _data.begin() + _size);
        for (int i = _size - 1; i >= 1; i--) {
            update(i);
        }
    }

    // Assign x to position p, then recompute every ancestor of that leaf.
    // Requires 0 <= p < n.
    void set(int p, T x) {
        p += _size;
        _data[p] = x;
        for (int i = 1; p >> i >= 1; i++) {
            update(p >> i);
        }
    }

    // Value at position p. Requires 0 <= p < n.
    T get(int p) const {
        return _data[p + _size];
    }

    // Product a[l] * a[l + 1] * ... * a[r - 1] (left to right), or M::id()
    // for the empty range l == r. Requires 0 <= l <= r <= n.
    // Left and right partial products are accumulated separately so that a
    // non-commutative op is still applied in the correct order.
    T prod(int l, int r) const {
        T sml = M::id(), smr = M::id();
        l += _size;
        r += _size;
        while (l < r) {
            if (l & 1) sml = M::op(sml, _data[l++]);
            if (r & 1) smr = M::op(_data[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return M::op(sml, smr);
    }

    // Product of all n elements. This is the root; padding leaves hold
    // M::id() and so do not contribute.
    T all_prod() const {
        return _data[1];
    }

    // Binary search to the right of l. Returns r with l <= r <= n such that
    //   - f(prod(l, r)) is true, and
    //   - r == n or f(prod(l, r + 1)) is false.
    // If f is monotone (false for a product implies false for any longer
    // product from l), this is the largest such r.
    // Requires 0 <= l <= n and f(M::id()) == true.
    int max_right(int l, auto f) const {
        // is_invocable_r_v (unlike invoke_result_t) is itself well-formed for a
        // non-invocable f, so this message is what the user actually sees.
        static_assert(std::is_invocable_r_v<bool, decltype(f)&, T>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (l == _n) return _n;
        l += _size;
        T sm = M::id();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(M::op(sm, _data[l]))) {
                // Descend: absorb a left child whenever f still holds.
                while (l < _size) {
                    l = (2 * l);
                    if (f(M::op(sm, _data[l]))) {
                        sm = M::op(sm, _data[l]);
                        l++;
                    }
                }
                return l - _size;
            }
            sm = M::op(sm, _data[l]);
            l++;
        } while ((l & -l) != l);  // stop once l is a power of two (right edge)
        return _n;
    }

    // Binary search to the left of r. Returns l with 0 <= l <= r such that
    //   - f(prod(l, r)) is true, and
    //   - l == 0 or f(prod(l - 1, r)) is false.
    // If f is monotone (false for a product implies false for any longer
    // product ending at r), this is the smallest such l.
    // Requires 0 <= r <= n and f(M::id()) == true.
    int min_left(int r, auto f) const {
        // See max_right for why is_invocable_r_v is used here.
        static_assert(std::is_invocable_r_v<bool, decltype(f)&, T>,
                      "f must be a callable that takes a Monoid::value_type and returns a boolean");
        if (r == 0) return 0;
        r += _size;
        T sm = M::id();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(M::op(_data[r], sm))) {
                // Descend: absorb a right child whenever f still holds.
                while (r < _size) {
                    r = (2 * r + 1);
                    if (f(M::op(_data[r], sm))) {
                        sm = M::op(_data[r], sm);
                        r--;
                    }
                }
                return r + 1 - _size;
            }
            sm = M::op(_data[r], sm);
        } while ((r & -r) != r);  // stop once r is a power of two (left edge)
        return 0;
    }
};

}  // namespace m1une
Back to top page