Post

[BOJ 13537] 수열과 쿼리 1

[BOJ 13537] 수열과 쿼리 1

백준 로고

백준 문제 링크

🔍 문제 분석

  • 본 문제는 특정 구간이 주어졌을 때 구간 내에 k이상의 원소의 수를 구하는 문제이다. 만약, 특정 구간이 주어졌을 때 구간 내의 원소가 정렬되어 있다면 이분 탐색을 통해 O(logN)내에 k이상의 원소의 수를 구할 수 있다.
  • 세그먼트 트리의 각 노드가 해당 구간 내의 원소를 정렬한 배열을 나타낸다고 하면, 임의의 구간에 대해서 k이상의 원소의 수를 구하는 것은 간단하다. 따라서, 세그먼트 트리를 생성하는 방법에 대해서 고민해야 한다.
  • 세그먼트 트리상향식(Bottom-Up)으로 초기화되므로 정렬된 두 배열을 병합해서 새로운 정렬된 배열을 만드는 방법을 생각해야 한다. 이 방법은 병합 정렬(merge sort)를 통해서 잘 알려져있다. 자세한 내용은 코드를 통해서 확인한다.

💻 코드 구현

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <bits/stdc++.h>
#define fastio cin.tie(0)->ios::sync_with_stdio(0)

using namespace std;
using ll = long long;
using pi = pair<int, int>;
using pll = pair<ll, ll>;
using vi = vector<int>;
using vvi = vector<vector<int>>;
using vvll = vector<vector<ll>>;
using vll = vector<ll>;
using vpi = vector<pi>;

constexpr int MAX = 100'000;

int n, m, arr[MAX];
vi tree[4 * MAX];

vi merge(const vi& lhs, const vi& rhs) {
    vi ret;
    int lidx = 0, ridx = 0;
    while (lidx < lhs.size() || ridx < rhs.size()) {
        if (lidx == lhs.size()) ret.push_back(rhs[ridx++]);
        else if (ridx == rhs.size()) ret.push_back(lhs[lidx++]);
        else if (lhs[lidx] < rhs[ridx]) ret.push_back(lhs[lidx++]);
        else ret.push_back(rhs[ridx++]);
    }
    return ret;
}

int count(const vi& v, int k) {
    return distance(upper_bound(v.begin(), v.end(), k), v.end());
}

int l(int node) {
    return 2 * node;
}

int r(int node) {
    return 2 * node + 1;
}

vi init(int begin, int end, int node) {
    if (begin == end) return tree[node] = {arr[begin]};
    int mid = (begin + end) / 2;
    return tree[node] = merge(init(begin, mid, l(node)), init(mid + 1, end, r(node)));
}

int query(int begin, int end, int node, int left, int right, int k) {
    if (left <= begin && end <= right) return count(tree[node], k);
    else if (left > end || begin > right) return 0;
    int mid = (begin + end) / 2;
    return query(begin, mid, l(node), left, right, k) + query(mid + 1, end, r(node), left, right, k);
}

int main() {
    fastio;

    cin >> n;
    for (int i = 0; i < n; ++i) cin >> arr[i];
    init(0, n - 1, 1);
    cin >> m;
    for (int i = 0; i < m; ++i) {
        int left, right, k;
        cin >> left >> right >> k;
        --left; --right;
        cout << query(0, n - 1, 1, left, right, k) << "\n";
    }
}

📝 코드 설명

🔧 트러블 슈팅

📚 참고자료

This post is licensed under CC BY 4.0 by the author.