Segment tree를 활용하는 법에 익숙해져보자!

문제

11505.cpp

문제 풀이

이번 문제는 segment tree의 기본였던 구간 합 구하기 문제에서 조금 활용하여서 최솟값을 구하는 문제이다. 만약 Segment tree에 대해서 잘 모르겠다면 아래 링크의 문제를 먼저 풀어보거나 풀이과정을 보고 오는 것을 추천한다.

백준 구간합 구하기 2042: https://koreaygj.github.io/algorithm/bj2042/

구간 곱 구하기 문제는 다른 Segment tree문제와 비슷하게 트리를 구성하지만 그 트리의 내부가 곱셈으로 이루어져 있다는 점이 달랐다. 또한 가장 필자가 많이 틀렸던 부분은 요소를 바꾸고 난 후에 해줘야 하는 update함수에서 실수가 생기기 쉬웠다.

update 함수

void update(int start, int end, int node, int index, long long int change)
{
    if(index < start || index > end)
        return;
    if(start == end)
    {
        tree[node] = change;
        return;
    }
    int mid = (start + end) / 2;
    update(start, mid, node * 2, index, change);
    update(mid + 1, end, node * 2 + 1, index, change);
    tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % div_num;
}

처음에 드는 생각은 구간합을 갱신시켜 주는 것처럼 모든 트리를 나누면 되지 않을까 라는 생각을 했는데 그것보다는 트리의 최하단에 있는 수를 갱신 시킨 이후에 그 수들을 곱한것으로 트리를 최신화 시키는 방식으로 해야 한다. 왜냐하면 구간합을 갱신시켜주면서 mod값이 변경되는 경우가 생기기 때문이다.


전체코드 자세히 보기
#include <iostream>
#include <algorithm>
#include <vector>
#define div_num 1000000007
using namespace std;
vector<long long int> tree(1000005 * 4, 0);
vector<int> arr(1000005, 0);
long long int init(int start, int end, int node)
{
    if(start == end)
        return tree[node] = arr[start] % div_num;
    int mid = (start + end) / 2;
    return tree[node] = (init(start, mid, node * 2)* init(mid + 1, end, node * 2 + 1)) % div_num;
}
long long int multiplex(int start, int end, int node, int left, int right)
{
    if(left > end || right < start)
        return 1;
    if(left <= start && end <= right)
        return tree[node] % div_num;
    int mid = (start + end) / 2;
    return (multiplex(start, mid, node * 2, left, right) * multiplex(mid + 1, end, node * 2 + 1, left, right)) % div_num;
}
void update(int start, int end, int node, int index, long long int change)
{
    if(index < start || index > end)
        return;
    if(start == end)
    {
        tree[node] = change;
        return;
    }
    int mid = (start + end) / 2;
    update(start, mid, node * 2, index, change);
    update(mid + 1, end, node * 2 + 1, index, change);
    tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % div_num;
}
int main(void)
{
    cin.tie(NULL);
    cout.tie(NULL);
    ios::sync_with_stdio(false);
    int n, m, k;
    cin >> n >> m >> k;
    for(int i = 1; i <= n; i++)
        cin >> arr[i];
    init(1, n, 1);
    for(int i = 0; i < m + k; i++)
    {
        int a;
        cin >> a;
        if(a == 1)
        {
            int b;
            long long int c;
            cin >> b >> c;
            update(1, n, 1, b, c);
        }
        else if(a == 2)
        {
            int b, c;
            cin >> b >> c;
            cout << multiplex(1, n, 1, b, c) << "\n";
        }
    }
}

문제 결과 result

문제 출처 https://www.acmicpc.net/problem/11505

댓글남기기