algorithm/Algorithm-Core

[Union Find] - Disjoint Set

아르비스 2020. 8. 6. 10:02

간선으로 연결하려는 두 노드가 이미 같은 집합에 속해 있는지를 판단해 준다. 이미 같은 집합에 속해 있다면, 그 간선을 추가하는 순간 사이클이 생긴다.

 

간략 코드

static class DJS {
    /** Parent pointer of each node; x is a root iff mom[x] == x. Nodes are 1..n, index 0 unused. */
    int[] mom;

    /**
     * Creates n singleton sets, one per node 1..n.
     *
     * @param n number of nodes (1-indexed)
     */
    DJS(int n) {
        mom = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            mom[i] = i;
        }
    }

    /**
     * Returns the root (representative) of the set containing {@code n} and compresses the path
     * so every node visited points directly at that root afterwards.
     *
     * <p>Iterative on purpose: union() does not balance trees, so a tree can degenerate into a
     * chain, and the recursive form would recurse once per level and overflow the call stack on
     * large inputs. The final state of {@code mom} is the same as with recursive compression.
     *
     * @param n node to look up
     * @return root of n's set
     */
    int find(int n) {
        int root = n;
        while (root != mom[root]) {
            root = mom[root];
        }
        // Second pass: re-point every node on the path straight at the root.
        int cur = n;
        while (cur != root) {
            int next = mom[cur];
            mom[cur] = root;
            cur = next;
        }
        return root;
    }

    /**
     * Merges the sets containing {@code a} and {@code b}.
     *
     * @return {@code false} if a and b were already in the same set (adding edge a-b would close a
     *     cycle), {@code true} if two distinct sets were merged
     */
    boolean union(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) {
            return false;
        }
        mom[b] = a; // b's root is attached under a's root
        return true;
    }
}

 

 

Sample code

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

/**
 * Sample program: for each test case, reads a graph of N nodes and M weighted edges plus two
 * nodes S and E. It prints the smallest (max edge cost - min edge cost) over edge sets that
 * connect S and E, or Integer.MAX_VALUE if they can never be connected.
 *
 * <p>Approach: sort edges by cost. For each start index idx, add edges idx, idx+1, ... with a
 * disjoint set until S and E become connected. The connecting edge's cost minus nodes[idx].c is
 * the candidate answer for that start. The answer is the minimum over all starts.
 */
public class pre19091 {
    static int N, M, S, E;
    static Node[] nodes;
    static DJS djs;

    /** Edge between s and e with cost c (the union-find treats it as undirected); ordered by c. */
    static class Node implements Comparable<Node> {
        int s, e, c;

        Node(int s, int e, int c) {
            this.s = s;
            this.e = e;
            this.c = c;
        }

        @Override
        public int compareTo(Node o) {
            // Integer.compare instead of this.c - o.c: the subtraction overflows for large costs
            // of opposite sign and would corrupt the sort order.
            return Integer.compare(this.c, o.c);
        }
    }

    /** Disjoint set (union-find) over nodes 1..n. */
    static class DJS {
        /** Parent pointer of each node; x is a root iff mom[x] == x. Index 0 unused. */
        int[] mom;

        DJS(int n) {
            mom = new int[n + 1];
            for (int i = 1; i <= n; i++) {
                mom[i] = i;
            }
        }

        /**
         * Returns the root of n's set with full path compression. Iterative so that a degenerate
         * (chain-shaped) tree cannot overflow the call stack; the final mom[] state matches the
         * recursive version.
         */
        int find(int n) {
            int root = n;
            while (root != mom[root]) {
                root = mom[root];
            }
            int cur = n;
            while (cur != root) {
                int next = mom[cur];
                mom[cur] = root;
                cur = next;
            }
            return root;
        }

        /** Merges a's and b's sets; returns false if they were already the same set. */
        boolean union(int a, int b) {
            a = find(a);
            b = find(b);
            if (a == b) {
                return false;
            }
            mom[b] = a;
            return true;
        }
    }

    public static void main(String[] args) throws IOException {
        // NOTE: hard-coded local test input; delete this line to read from standard input
        // instead (e.g. when submitting to an online judge).
        System.setIn(new FileInputStream("res/sample_input_0910.txt"));
        long start = System.currentTimeMillis();
        // try-with-resources closes the reader and the underlying input stream.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
            StringTokenizer st;
            int T = Integer.parseInt(br.readLine().trim());
            for (int t = 1; t <= T; t++) {
                st = new StringTokenizer(br.readLine().trim(), " ");
                N = Integer.parseInt(st.nextToken());
                M = Integer.parseInt(st.nextToken());
                nodes = new Node[M];

                for (int i = 0; i < M; i++) {
                    st = new StringTokenizer(br.readLine().trim(), " ");
                    int a = Integer.parseInt(st.nextToken());
                    int b = Integer.parseInt(st.nextToken());
                    int c = Integer.parseInt(st.nextToken());
                    nodes[i] = new Node(a, b, c);
                }
                st = new StringTokenizer(br.readLine().trim(), " ");
                S = Integer.parseInt(st.nextToken());
                E = Integer.parseInt(st.nextToken());

                Arrays.sort(nodes);
                int result = Integer.MAX_VALUE;
                for (int i = 0; i < M; i++) {
                    result = Math.min(result, find(i));
                }

                System.out.println("#" + t + " " + result);
            }
        }
        System.out.println("Total : " + (System.currentTimeMillis() - start) + " ms");
    }

    /**
     * Rebuilds {@link #djs} and adds edges nodes[idx], nodes[idx+1], ... until S and E are
     * connected.
     *
     * <p>Requires nodes[0..M) sorted by ascending cost (main sorts before calling). Every edge
     * from idx on therefore costs at least nodes[idx].c, and the first edge that connects S and E
     * is the cheapest possible maximum for this start.
     *
     * @param idx index of the cheapest edge allowed
     * @return (cost of the connecting edge - nodes[idx].c), or Integer.MAX_VALUE if edges
     *     idx..M-1 never connect S and E
     */
    static int find(int idx) {
        djs = new DJS(N);
        int minCost = nodes[idx].c;
        for (int i = idx; i < M; i++) {
            djs.union(nodes[i].s, nodes[i].e);
            if (djs.find(S) == djs.find(E)) {
                return nodes[i].c - minCost;
            }
        }
        return Integer.MAX_VALUE;
    }
}