algorithm

LCA(Lowest Common Ancestor)

아르비스 2019. 8. 12. 08:55

LCA를 직역하면 최소 공통 조상(?) 정도의 뜻으로 해석되며 두 정점에서 (자신을 포함한)조상들을 거슬러 올라갈 때 처음으로 공통되게 만나는 정점을 지칭합니다.

즉, Tree에서 두점 간의 거리를 구하는 방식으로 사용됨.

 

트리에서 LCA(1,N)을 구하는 쿼리는 O(N)의 시간복잡도를 가지게 됩니다.

 

하지만 LCA2와 같은 문제에서 O(N)에 쿼리를 처리한다면 우리는 O(N*M)의 시간복잡도를 가지고 시간초과를 보게 될 것입니다.

 

쿼리의 개수인 M을 조절해 줄 수는 없으니 적어도 LCA를 O(logN)의 시간복잡도에 구해줄 무언가가 필요합니다

 

우리는 O(N)의 방식으로 LCA를 구하는 방법을 약간 변형 시켜서 O(logN)의 시간복잡도에 LCA를 구해보겠습니다.

 

우선 전처리가 필요한데요. O(N)의 방식으로 dfs를 돌리면서 각 정점의 깊이와 부모를 저장합니다.

 

이때 미리 말하자면 우리는 한 정점의 2^N의 부모를 저장할 것이기 때문에 par[x][i] 배열(x: 정점번호 i: 2^i번째 조상을 의미)을 선언 한뒤 par[x][0]에 부모를 채워줍니다. 

 

DFS에 의한 전처리가 끝나면 반복문을 이용하여 각 정점들의 (2^i)번째 조상을 채워줄 것입니다.

 

이는 par[x][i]=par[par[x][i-1]][i-1]이라는 수식을 이용하면 채울 수 있습니다.


시간 복잡도를 줄이는게 꼭 필요함.

 

static void dfs(int here,int depth, int dis) {
        visited[here] = true;
        dep[here] = depth;
        d[here] = dis;
        for (Node nd : vt[here]) {
            if (visited[nd.e])
                continue;
            par[nd.e][0] = here;
            dfs(nd.e, depth + 1, dis + nd.c);
        }
    }
    static void find() {
        for (int j = 1; j < 21; j++) {
            for (int i = 1; i <= N; i++) {
                par[i][j] = par[par[i][j - 1]][j - 1];
            }
        }
    }

j를 왜 20까지 구하냐고 물으신다면 N이 10만이기 때문에 2^20은 100만 이상의 수로 2^20번째 조상까지만 채워주더라도 충분하기 때문입니다.

 

자 이제 2^i번째 조상들을 저장하여 이것을 어떻게 이용할 것이냐고 물어보면 아까와 같은 행동을 반복할것입니다.

 

LCA를 구하려는 두 정점의 높이를 맞춰준 후 똑같이 올라가며 LCA를 구할 것입니다.

 

단 우리는 2^i번째 조상들을 구해놨기 때문에 2^i만큼씩 지수승 씩 높이를 증가시킬 수 있습니다

 

static int lca(int x, int y) {
        if (dep[x] > dep[y]) {
            int tmp = x;
            x = y;
            y = tmp;
        }
        for (int i = 20; i >= 0; i--) {
            if (dep[y] - dep[x] >= (1 << i))
                y = par[y][i];
        }
        if (x == y)return x;
        for (int i = 20; i >= 0; i--) {
            if (par[x][i] != par[y][i]) {
                x = par[x][i];
                y = par[y][i];
            }
        }
        return par[x][0];
    }

 

참조 문제

https://www.acmicpc.net/problem/1761

 

1761번: 정점들의 거리

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 거리를 알고 싶은 M개의 노드 쌍이 한 줄에 한 쌍씩 입력된다. 두 점 사이의 거리는 10,000보다 작거나 같은 자연수이다. 정점은 1번부터 N번까지 번호가 매겨져 있다.

www.acmicpc.net

lca를 이용하여 두 노드의 최단 거리를 계산할 것이다.

 

lca전처리를 위해 dfs를 돌려줄 때 루트에서 부터 자기 자신까지의 거리를 dist배열에 저장해두면

 

두 노드의 X,Y 의최단 거리는 dist[X]+dist[Y]-2*dist[lca(X,Y)]로 정의 할 수있다.

 

이를 이용하여 매 쿼리를 logN시간에 처리해주면 된다.




import java.io.*;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class LCA {
    static int N, Q;
    static final int MAX_N = 100000;
    static int[] d, dep;  // d: dist , dep : depth
    static int[][] par;
    static boolean[] visited;
    static ArrayList<Node>[] vt;
    static class Node {
        int e, c;
        Node(int e, int c) {
            this.e = e;
            this.c = c;
        }
    }
    static void dfs(int here,int depth, int dis) {
        visited[here] = true;
        dep[here] = depth;
        d[here] = dis;
        for (Node nd : vt[here]) {
            if (visited[nd.e])
                continue;
            par[nd.e][0] = here;
            dfs(nd.e, depth + 1, dis + nd.c);
        }
    }
    static void find() {
        for (int j = 1; j < 21; j++) {
            for (int i = 1; i <= N; i++) {
                par[i][j] = par[par[i][j - 1]][j - 1];
            }
        }
    }
    static int lca(int x, int y) {
        if (dep[x] > dep[y]) {
            int tmp = x;
            x = y;
            y = tmp;
        }
        for (int i = 20; i >= 0; i--) {
            if (dep[y] - dep[x] >= (1 << i))
                y = par[y][i];
        }
        if (x == y)return x;
        for (int i = 20; i >= 0; i--) {
            if (par[x][i] != par[y][i]) {
                x = par[x][i];
                y = par[y][i];
            }
        }
        return par[x][0];
    }
    public static void main(String[] args) throws IOException {
       // System.setIn(new FileInputStream("res/input_lca.txt"));
        long Start = System.currentTimeMillis();
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        N = Integer.parseInt(br.readLine().trim());
        par = new int[N+1][21];
        visited = new boolean[N+1];
        d = new int[N+1];
        dep = new int[N+1];
        vt = new ArrayList[N+1];
        int x, y, z;
        for (int i = 1; i <=N-1 ; i++) {
            st = new StringTokenizer(br.readLine().trim(), " ");
            x = Integer.parseInt(st.nextToken());
            y = Integer.parseInt(st.nextToken());
            z = Integer.parseInt(st.nextToken());
            if(vt[x]==null) vt[x] = new ArrayList<Node>();
            if(vt[y]==null) vt[y] = new ArrayList<Node>();
            vt[x].add(new Node(y,z));
            vt[y].add(new Node(x,z));
        }

        dfs(1, 0, 0);
        find();

        Q = Integer.parseInt(br.readLine().trim());
        int result = 0;
        for (int i = 1; i <= Q ; i++) {
            st = new StringTokenizer(br.readLine().trim(), " ");
            x = Integer.parseInt(st.nextToken());
            y = Integer.parseInt(st.nextToken());
            result = d[x]+d[y] - (2*d[lca(x,y)]);
            System.out.println(result);
        }
      //  System.out.println("Total : " + (System.currentTimeMillis()-Start) + "ms");
    }
}