Post

크루스칼 알고리즘

크루스칼 알고리즘

간선들의 비용을 기준을 정렬한 뒤, 사이클을 이루지 않게끔 최소 신장 트리를 형성하는 알고리즘

  1. 비용을 기준으로 간선들을 오름차순으로 정렬한다.
  2. 간선들을 하나씩 확인하며 사이클을 발생시키는지 확인한다.
    • 사이클을 발생시킬 경우 포함 x
    • 사이클을 발생시키지 않으면 최소 신장 트리에 포함시킨다.
  3. 모든 간선에 대하여 2번 과정을 반복한다.

크루스칼 알고리즘

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 사이클 여부 판별을 위한 union-find 함수들
def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

n, e = map(int, input().split()) # 노드의 개수와 간선의 개수
parent = [x for x in range(n+1)]
line = []

for _ in range(e):
    a, b, cost = map(int, input().split())
    line.append([cost, a, b]) # 비용을 기준으로 정렬하기 위해 맨 앞에 배치

line.sort()

total_cost = 0
for cost, a, b in line:
    if find(parent, a) != find(parent, b): # 사이클을 이루지 않으면,
        union(parent, a, b)
        total_cost += cost

print(total_cost)

백준 - 도시 분할 계획

마을은 N개의 집과 그 집들을 연결하는 M개의 길로 이루어져 있다. 길은 어느 방향으로든지 다닐 수 있는 편리한 길이다. 그리고 각 길마다 길을 유지하는데 드는 유지비가 있다. 임의의 두 집 사이에 경로가 항상 존재한다.

마을의 이장은 마을을 두 개의 분리된 마을로 분할할 계획을 가지고 있다. 마을이 너무 커서 혼자서는 관리할 수 없기 때문이다. 마을을 분할할 때는 각 분리된 마을 안에 집들이 서로 연결되도록 분할해야 한다. 각 분리된 마을 안에 있는 임의의 두 집 사이에 경로가 항상 존재해야 한다는 뜻이다. 마을에는 집이 하나 이상 있어야 한다.

그렇게 마을의 이장은 계획을 세우다가 마을 안에 길이 너무 많다는 생각을 하게 되었다. 일단 분리된 두 마을 사이에 있는 길들은 필요가 없으므로 없앨 수 있다. 그리고 각 분리된 마을 안에서도 임의의 두 집 사이에 경로가 항상 존재하게 하면서 길을 더 없앨 수 있다. 마을의 이장은 위 조건을 만족하도록 길들을 모두 없애고 나머지 길의 유지비의 합을 최소로 하고 싶다. 이것을 구하는 프로그램을 작성하시오.

입력
첫째 줄에 집의 개수 N, 길의 개수 M이 주어진다. N은 2이상 100,000이하인 정수이고, M은 1이상 1,000,000이하인 정수이다. 그 다음 줄부터 M줄에 걸쳐 길의 정보가 A B C 세 개의 정수로 주어지는데 A번 집과 B번 집을 연결하는 길의 유지비가 C (1 ≤ C ≤ 1,000)라는 뜻이다.

임의의 두 집 사이에 경로가 항상 존재하는 입력만 주어진다.

출력
첫째 줄에 없애고 남은 길 유지비의 합의 최솟값을 출력한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
input = sys.stdin.readline

def find_parent(parent, x):
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

def union_parent(parent, a, b):
    a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

n, m = map(int, input().split())

parent = [x for x in range(n+1)]
line = []

for _ in range(m):
    a, b, cost = map(int, input().split())
    line.append([cost, a, b])

line.sort()

total_cost = 0
save = 0
for cost, a, b in line:
    if find_parent(parent, a) != find_parent(parent, b):
        union_parent(parent, a, b)
        total_cost += cost
        save = cost # 연결된 간선 중 가장 큰 비용을 저장하기 위함

print(total_cost-save)

아주 기본적인 크루스칼 알고리즘 문제이다. 마을을 두 개로 분리하라고 했고, 마을을 이룰 수 있는 집의 최소 개수가 주어지지 않았기 때문에, 비용이 가장 큰 간선 하나를 끊어버리면 된다.

백준 - 별자리 만들기

도현이는 우주의 신이다. 이제 도현이는 아무렇게나 널브러져 있는 n개의 별들을 이어서 별자리를 하나 만들 것이다. 별자리의 조건은 다음과 같다.

  • 별자리를 이루는 선은 서로 다른 두 별을 일직선으로 이은 형태이다.
  • 모든 별들은 별자리 위의 선을 통해 서로 직/간접적으로 이어져 있어야 한다.

별들이 2차원 평면 위에 놓여 있다. 선을 하나 이을 때마다 두 별 사이의 거리만큼의 비용이 든다고 할 때, 별자리를 만드는 최소 비용을 구하시오.

입력
첫째 줄에 별의 개수 n이 주어진다. (1 ≤ n ≤ 100)

둘째 줄부터 n개의 줄에 걸쳐 각 별의 x, y좌표가 실수 형태로 주어지며, 최대 소수점 둘째자리까지 주어진다. 좌표는 1000을 넘지 않는 양의 실수이다.

출력
첫째 줄에 정답을 출력한다. 절대/상대 오차는 10^(-2)까지 허용한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sys
input = sys.stdin.readline

n = int(input())

def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def distance(a, b):
    return ((abs(a[0]-b[0]))**2 + (abs(a[1]-b[1])**2))**(1/2)

line = []
point = []
parent = [x for x in range(n+1)]

for _ in range(n):
    x, y = map(float, input().split())
    point.append([x, y])

for i in range(n):
    for j in range(n):
        if i != j:
            cost = distance(point[i],point[j])
            line.append([cost, i, j])

total_cost = 0

line.sort()
for cost, a, b in line:
    if find(parent, a) != find(parent, b):
        union(parent, a, b)
        total_cost += cost

print(round(total_cost, 2))

거리가 곧 비용이므로, distance 함수만 추가해서 거리를 계산해주는 부분만 추가되었다.

백준 - 행성 터널

때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.

행성은 3차원 좌표 위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(abs(xA-xB), abs(yA-yB), abs(zA-zB))이다.

민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.

입력
첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -10^9보다 크거나 같고, 10^9보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다.

출력
첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.

메모리 초과 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
input = sys.stdin.readline

def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def distance(a, b):
    return min(abs(a[0]-b[0]), abs(a[1]-b[1]), abs(a[2]-b[2]))

n = int(input())
point = []
parent = [x for x in range(n+1)]

for _ in range(n):
    x, y, z = map(int, input().split())
    point.append([x, y, z])

line = []
for i in range(n):
    for j in range(n):
        if i != j:
            line.append([distance(point[i], point[j]), i, j])

line.sort()
total_cost = 0

for cost, i, j in line:
    if find(parent, i) != find(parent, j):
        union(parent, i, j)
        total_cost += cost

print(total_cost)

좌표의 범위가 넓기 때문에 간선들을 모두 저장해버리면 메모리를 초과하게 된다. (저장해야할 간선의 수 n(n-1))
x좌표, y좌표, z좌표를 각기 다른 리스트에 입력받아서, sort를 해주고, 순서대로 차이를 저장해주면 저장되는 간선의 수는 3(n-1)이 된다.

정답 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import sys
input = sys.stdin.readline

def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

def union(parent, a, b):
    a = find(parent, a)
    b = find(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

n = int(input())
point = []
parent = [x for x in range(n)]

xlist = []
ylist = []
zlist = []

for i in range(n):
    x, y, z = map(int, input().split())
    xlist.append([x, i])
    ylist.append([y, i])
    zlist.append([z, i])

xlist.sort()
ylist.sort()
zlist.sort()

line = []

for i in range(1,n):
    dist1, a = xlist[i-1]
    dist2, b = xlist[i]
    line.append([abs(dist1 - dist2), a, b])

for i in range(1,n):
    dist1, a = ylist[i-1]
    dist2, b = ylist[i]
    line.append([abs(dist1 - dist2), a, b])

for i in range(1, n):
    dist1, a = zlist[i - 1]
    dist2, b = zlist[i]
    line.append([abs(dist1 - dist2), a, b])

line.sort()
total_cost = 0


for cost, i ,j in line:
    if find(parent, i) != find(parent, j):
        union(parent, i, j)
        total_cost += cost

print(total_cost)
This post is licensed under CC BY 4.0 by the author.