[백준 13544 | 머지 소트 트리] 수열과 쿼리 3

2020. 6. 21. 15:30Computer Science/Problem Solving (PS)

 

 

 

풀이

 

세그먼트트리를 응요한 자료구조인 머지 소트 트리(Merge Sort Tree)를 이용하면 O((logn)^2)시간 안에 해결할 수 있는 문제입니다. 머지 소트트리는 간단하게 말해서 트리의 각 노드에 자식들의 최솟값이나 최댓값을 저장하는 것이 아니라, 머지소트(Merge Sort)시에 일어나는 각 배열들의 중간 상태를 저장하는 자료 구조입니다. 

 

 

http://ivandemarino.me/2010/01/06/the-polite-merge-sort/

 

[5, 2, 4, 6, 1, 3, 2, 6]의 배열을 머지소트를 이용해서 정렬할 경우, 위의 그림과 같은 과정을 거치게 됩니다. 이 때, 각 노드의 상위 노드에는 두 배열을 머지하여 생긴 새로운 배열을 저장합니다. 즉 배열(혹은 리스트)의 배열 구조를 띄게 되는 것입니다.

 

 

이렇게 머지소트트리를 생성하고 난 후에는 쿼리를 날려서, 각 구간안에 포함되는 배열들 중에 k보다 큰 값이 얼마나 있는지를 확인해 줍니다. 예를 들어서 위의 그림에서 2~6 구간(시작 인덱스를 0이라 했을 때) 중에 3보다 큰 값을 찾아야 한다고 하겠습니다. 우선 머지 소트트리는 세그먼트 트리의 변형이기 때문에 해당 구간이 [4, 6], [1, 3], [2] 의 3가지 구간의 조합임을 O(logn)타임만에 알 수 있습니다. 

 

 

이때 이 각 구간에서 k보다 큰 값의 개수들을 리턴하도록 하고, 이들을 모두 합하면 2~6 구간 전체에서 k보다 큰 값의 개수가 몇 개인지를 구할 수 있는 것입니다. 각 구간은 머지 소트의 원리에 의해서 정렬되어 있는 상태이기 때문에 upper_bound함수를 구현해서(C++에서는 라이브러리로 구현되어 있으나, JAVA는 그렇지 않아서 직접 구현해야 했습니다.) k보다 큰 값이 처음으로 나오는 인덱스를 해당 구간의 전체 길이에서 빼는 방식으로 각 구간에서 k보다 큰 값의 개수를 구하였습니다. 

 

 

소스 코드 (JAVA)

import java.util.*;
import java.io.*;

public class BOJ_13544 {
    public static void main(String[] args) throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer tok = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(tok.nextToken());
        int[] arr = new int[n];

        tok = new StringTokenizer(br.readLine());
        for(int i=0;i<n;i++) arr[i] = Integer.parseInt(tok.nextToken());
        MergeSortTree mst = new MergeSortTree(n, arr);

        tok = new StringTokenizer(br.readLine());
        int m = Integer.parseInt(tok.nextToken());
        int ans = 0;
        for(int i=0;i<m;i++) {
            tok = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(tok.nextToken()) ^ ans;
            int end = Integer.parseInt(tok.nextToken()) ^ ans;
            int k = Integer.parseInt(tok.nextToken()) ^ ans;
            ans = mst.getGreater(start-1, end-1, 1, 0, n-1, k);
            bw.write(ans  + "\n");
        }
        bw.flush();
        bw.close();
    }
}

class MergeSortTree {
    int n;
    int[] arr;
    ArrayList<Integer>[] tree;

    MergeSortTree(int n, int[] arr) {
        this.n = n;
        this.arr = arr;
        this.tree = new ArrayList[n * 4];
        this.init(0, n-1, 1);
    }

    public ArrayList<Integer> init(int left, int right, int pos) {
        // in the case of the leaf node
        if (left == right) {
            tree[pos] = new ArrayList<>();
            tree[pos].add(arr[left]);
            return tree[pos];
        }
        int mid = (left + right) / 2;
        ArrayList<Integer> leftArr = init(left, mid, pos * 2);
        ArrayList<Integer> rightArr = init(mid+1, right, pos * 2 + 1);
        return tree[pos] = merge(leftArr, rightArr);
    }

    public int getGreater(int left, int right, int pos, int posLeft, int posRight, int val) {
        if (right < posLeft || posRight < left) return 0;
        if (left <= posLeft && posRight <= right) return tree[pos].size() - upper_bound(tree[pos], val);
        int mid = (posLeft + posRight) / 2;
        return getGreater(left, right, pos * 2, posLeft, mid, val)
                + getGreater(left, right, pos * 2 + 1, mid+1, posRight, val);
    }

    public static int upper_bound(ArrayList<Integer> arrayList, int val) {
        int len = arrayList.size();
        int left = 0, right = len-1, mid = 0;
        while(left < right) {
            if (arrayList.get(mid) <= val) left = mid + 1;
            else right = mid;
            mid = (left + right) / 2;
            if (mid == right) {
                if (arrayList.get(mid) <= val) return len;
                else return right;
            }
        }
        if(arrayList.get(left) > val) return 0;
        else return left + 1;
    }

    public ArrayList<Integer> merge(ArrayList<Integer> leftArr, ArrayList<Integer> rightArr) {
        ArrayList<Integer> returnArr = new ArrayList<>();
        int i=0, j=0, leftLen = leftArr.size(), rightLen = rightArr.size();
        while(i < leftLen && j < rightLen) {
            if (leftArr.get(i) <= rightArr.get(j)) {
                returnArr.add(leftArr.get(i));
                i++;
            } else {
                returnArr.add(rightArr.get(j));
                j++;
            }
        }
        while(i < leftLen) {
            returnArr.add(leftArr.get(i));
            i++;
        }
        while(j < rightLen) {
            returnArr.add(rightArr.get(j));
            j++;
        }
        return returnArr;
    }
}
반응형