m1une's library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

✔ verify/unit_test/persistent_treap.test.cpp

Depends on

- data_structure/bst/persistent_treap.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/range_kth_smallest"

#include "data_structure/bst/persistent_treap.hpp"

#include <algorithm>
#include <iostream>
#include <vector>

// Fast I/O
// Fast I/O: stop synchronizing C++ streams with C stdio and untie std::cin from
// std::cout so that reading input does not force a flush of pending output.
// Only safe when the program does not mix <cstdio> calls with iostreams.
void fast_io() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}

// Reads an array a[0..N) and Q queries (l, r, k), printing the k-th smallest
// (0-indexed) value of a[l..r-1] for each query. One persistent treap is kept
// per array prefix; each query binary-searches over compressed ranks.
int main() {
    fast_io();
    int N, Q;
    std::cin >> N >> Q;
    std::vector<int> a(N);
    for (int& x : a) {
        std::cin >> x;
    }

    // Coordinate compression: a value's rank is its index among the sorted distinct values.
    std::vector<int> distinct_elements(a.begin(), a.end());
    std::sort(distinct_elements.begin(), distinct_elements.end());
    distinct_elements.erase(std::unique(distinct_elements.begin(), distinct_elements.end()), distinct_elements.end());
    const int num_ranks = static_cast<int>(distinct_elements.size());

    auto get_compressed_rank = [&](int val) {
        return static_cast<int>(std::lower_bound(distinct_elements.begin(), distinct_elements.end(), val) -
                                distinct_elements.begin());
    };

    // versions[i] holds the compressed ranks of the prefix a[0..i).
    std::vector<m1une::persistent_treap<int>> versions(N + 1);
    for (int i = 0; i < N; ++i) {
        versions[i + 1] = versions[i].insert(get_compressed_rank(a[i]));
    }

    for (int q = 0; q < Q; ++q) {
        int l, r, k;
        std::cin >> l >> r >> k;

        // Meguru-style binary search for the smallest rank `ok` such that more than k
        // elements of a[l..r-1] have rank <= ok.
        // Invariant: count(ng) <= k < count(ok). The initial ok = num_ranks - 1 counts all
        // r - l elements, which is valid assuming the input guarantees 0 <= k < r - l.
        int ng = -1;
        int ok = num_ranks - 1;

        while (ok - ng > 1) {  // ok > ng always holds, so the gap is never negative
            const int mid = ng + (ok - ng) / 2;

            // Ranks < mid + 1 (i.e. <= mid) in prefix r minus those in prefix l.
            const int count_le = versions[r].order_of_key(mid + 1) - versions[l].order_of_key(mid + 1);

            if (count_le > k) {
                // mid already covers the k-th element; try a smaller rank.
                ok = mid;
            } else {
                // Too few elements at or below mid; the answer is larger.
                ng = mid;
            }
        }

        // Map the winning rank back to the original value.
        std::cout << distinct_elements[ok] << '\n';
    }

    return 0;
}
#line 1 "verify/unit_test/persistent_treap.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/range_kth_smallest"

#line 1 "data_structure/bst/persistent_treap.hpp"



#include <algorithm>
#include <ctime>
#include <iostream>
#include <memory>
#include <optional>
#include <random>

namespace m1une {

template <typename T>
struct persistent_treap {
   private:
    struct node;
    using node_ptr = std::shared_ptr<node>;

    // Draws a heap priority from a single engine per instantiation, seeded once.
    // Reseeding a shared generator each time an empty treap is created can give
    // consecutive insertions identical priorities, degenerating the treap into a
    // linked list; rand() may also have only 15 bits of range.
    static int random_priority() {
        static std::mt19937 engine(static_cast<std::mt19937::result_type>(std::random_device{}()) ^
                                   static_cast<std::mt19937::result_type>(std::time(nullptr)));
        return static_cast<int>(engine() >> 1);  // 32-bit output -> [0, 2^31), fits in int
    }

    // A tree node. Nodes reachable from a published treap are never modified;
    // every update copies the nodes on its path, so older versions stay valid.
    struct node {
        T _key;
        int _priority;  // max-heap priority: a parent's priority is >= its children's
        node_ptr _l, _r;
        int _count;  // number of nodes in the subtree rooted here

        explicit node(T key) : _key(key), _priority(random_priority()), _l(nullptr), _r(nullptr), _count(1) {}
    };

    node_ptr _root;

    // Number of nodes in subtree t (0 for an empty subtree).
    static int count(const node_ptr& t) {
        return t ? t->_count : 0;
    }

    // Recomputes t->_count from its children. Only called on freshly created
    // copies, so shared nodes of other versions are never touched.
    static void update_count(const node_ptr& t) {
        if (t) {
            t->_count = 1 + count(t->_l) + count(t->_r);
        }
    }

    // Splits t into l (keys <= key) and r (keys > key), copying the nodes on
    // the search path so that t itself is left intact.
    // `t` is taken by value on purpose: the recursive calls pass a child
    // pointer that is also one of the output slots.
    static void split(node_ptr t, const T& key, node_ptr& l, node_ptr& r) {
        if (!t) {
            l = r = nullptr;
            return;
        }
        auto new_node = std::make_shared<node>(*t);
        if (key < new_node->_key) {
            split(new_node->_l, key, l, new_node->_l);
            r = new_node;
        } else {
            split(new_node->_r, key, new_node->_r, r);
            l = new_node;
        }
        update_count(new_node);
    }

    // Concatenates l and r, where every key in l is <= every key in r.
    // Returns a new root; inputs are not modified.
    static node_ptr merge(node_ptr l, node_ptr r) {
        if (!l || !r) return l ? l : r;
        if (l->_priority > r->_priority) {
            auto new_node = std::make_shared<node>(*l);
            new_node->_r = merge(new_node->_r, r);
            update_count(new_node);
            return new_node;
        } else {
            auto new_node = std::make_shared<node>(*r);
            new_node->_l = merge(l, new_node->_l);
            update_count(new_node);
            return new_node;
        }
    }

    // Inserts the fresh node `item` into t and returns the new root.
    // Duplicates are allowed.
    static node_ptr insert_impl(node_ptr t, node_ptr item) {
        if (!t) return item;
        if (item->_priority > t->_priority) {
            // item becomes the root of this subtree; t is split around its key.
            split(t, item->_key, item->_l, item->_r);
            update_count(item);
            return item;
        }
        auto new_node = std::make_shared<node>(*t);
        if (item->_key < new_node->_key) {
            new_node->_l = insert_impl(new_node->_l, item);
        } else {
            new_node->_r = insert_impl(new_node->_r, item);
        }
        update_count(new_node);
        return new_node;
    }

    // Removes one occurrence of key from t and returns the new root.
    // If key is absent, t itself is returned, so nothing is copied.
    static node_ptr erase_impl(const node_ptr& t, const T& key) {
        if (!t) return nullptr;
        if (t->_key == key) return merge(t->_l, t->_r);
        const bool go_left = key < t->_key;
        const node_ptr& old_child = go_left ? t->_l : t->_r;
        node_ptr new_child = erase_impl(old_child, key);
        if (new_child == old_child) return t;  // key not found below: share subtree
        auto new_node = std::make_shared<node>(*t);
        if (go_left) {
            new_node->_l = new_child;
        } else {
            new_node->_r = new_child;
        }
        update_count(new_node);
        return new_node;
    }

    // k-th smallest key (0-indexed) in t, or T() when k is out of range.
    static T find_by_order_impl(const node_ptr& t, int k) {
        if (!t) return T();
        const int left_count = count(t->_l);
        if (k < left_count) return find_by_order_impl(t->_l, k);
        if (k == left_count) return t->_key;
        return find_by_order_impl(t->_r, k - left_count - 1);
    }

    // Number of keys in t strictly less than key.
    static int order_of_key_impl(const node_ptr& t, const T& key) {
        if (!t) return 0;
        if (key <= t->_key) return order_of_key_impl(t->_l, key);
        return count(t->_l) + 1 + order_of_key_impl(t->_r, key);
    }

    // Smallest key in t that is >= key, or nullopt.
    static std::optional<T> lower_bound_impl(const node_ptr& t, const T& key) {
        if (!t) return std::nullopt;
        if (key <= t->_key) {
            auto res = lower_bound_impl(t->_l, key);
            return res.has_value() ? res : t->_key;
        }
        return lower_bound_impl(t->_r, key);
    }

    // Smallest key in t that is > key, or nullopt.
    static std::optional<T> upper_bound_impl(const node_ptr& t, const T& key) {
        if (!t) return std::nullopt;
        if (key < t->_key) {
            auto res = upper_bound_impl(t->_l, key);
            return res.has_value() ? res : t->_key;
        }
        return upper_bound_impl(t->_r, key);
    }

   public:
    // Creates an empty treap.
    persistent_treap() : _root(nullptr) {}

    // Wraps an existing root (used internally to publish new versions).
    persistent_treap(std::shared_ptr<node> root) : _root(root) {}

    // Returns a new version containing key as well; *this is unchanged.
    [[nodiscard]] persistent_treap insert(T key) const {
        return persistent_treap(insert_impl(_root, std::make_shared<node>(key)));
    }

    // Returns a new version with one occurrence of key removed; *this is unchanged.
    [[nodiscard]] persistent_treap erase(T key) const {
        return persistent_treap(erase_impl(_root, key));
    }

    // k-th smallest key (0-indexed); returns T() if k is out of range.
    T find_by_order(int k) const {
        return find_by_order_impl(_root, k);
    }

    // Number of keys strictly less than key.
    int order_of_key(T key) const {
        return order_of_key_impl(_root, key);
    }

    // Smallest key >= key, or std::nullopt if none.
    std::optional<T> lower_bound(T key) const {
        return lower_bound_impl(_root, key);
    }

    // Smallest key > key, or std::nullopt if none.
    std::optional<T> upper_bound(T key) const {
        return upper_bound_impl(_root, key);
    }

    // Number of keys stored (duplicates counted).
    int size() const {
        return count(_root);
    }
};

}  // namespace m1une


#line 4 "verify/unit_test/persistent_treap.test.cpp"

#line 7 "verify/unit_test/persistent_treap.test.cpp"
#include <vector>

// Fast I/O
// Fast I/O: stop synchronizing C++ streams with C stdio and untie std::cin from
// std::cout so that reading input does not force a flush of pending output.
// Only safe when the program does not mix <cstdio> calls with iostreams.
void fast_io() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}

// Reads an array a[0..N) and Q queries (l, r, k), printing the k-th smallest
// (0-indexed) value of a[l..r-1] for each query. One persistent treap is kept
// per array prefix; each query binary-searches over compressed ranks.
int main() {
    fast_io();
    int N, Q;
    std::cin >> N >> Q;
    std::vector<int> a(N);
    for (int& x : a) {
        std::cin >> x;
    }

    // Coordinate compression: a value's rank is its index among the sorted distinct values.
    std::vector<int> distinct_elements(a.begin(), a.end());
    std::sort(distinct_elements.begin(), distinct_elements.end());
    distinct_elements.erase(std::unique(distinct_elements.begin(), distinct_elements.end()), distinct_elements.end());
    const int num_ranks = static_cast<int>(distinct_elements.size());

    auto get_compressed_rank = [&](int val) {
        return static_cast<int>(std::lower_bound(distinct_elements.begin(), distinct_elements.end(), val) -
                                distinct_elements.begin());
    };

    // versions[i] holds the compressed ranks of the prefix a[0..i).
    std::vector<m1une::persistent_treap<int>> versions(N + 1);
    for (int i = 0; i < N; ++i) {
        versions[i + 1] = versions[i].insert(get_compressed_rank(a[i]));
    }

    for (int q = 0; q < Q; ++q) {
        int l, r, k;
        std::cin >> l >> r >> k;

        // Meguru-style binary search for the smallest rank `ok` such that more than k
        // elements of a[l..r-1] have rank <= ok.
        // Invariant: count(ng) <= k < count(ok). The initial ok = num_ranks - 1 counts all
        // r - l elements, which is valid assuming the input guarantees 0 <= k < r - l.
        int ng = -1;
        int ok = num_ranks - 1;

        while (ok - ng > 1) {  // ok > ng always holds, so the gap is never negative
            const int mid = ng + (ok - ng) / 2;

            // Ranks < mid + 1 (i.e. <= mid) in prefix r minus those in prefix l.
            const int count_le = versions[r].order_of_key(mid + 1) - versions[l].order_of_key(mid + 1);

            if (count_le > k) {
                // mid already covers the k-th element; try a smaller rank.
                ok = mid;
            } else {
                // Too few elements at or below mid; the answer is larger.
                ng = mid;
            }
        }

        // Map the winning rank back to the original value.
        std::cout << distinct_elements[ok] << '\n';
    }

    return 0;
}
Back to top page