Post

최근접점 찾기

분할 정복 알고리즘을 이용하여 최근접점 찾기

구현하기 정말 어려웠던 알고리즘 중 하나였다. 아래의 코드를 뜯어보면서 다시 한 번 공부해보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import math

# 두 점 사이의 거리를 계산하는 함수
def calculate_distance(point1, point2):
    return ((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2) ** (1/2)

# 가장 가까운 두 점 사이의 거리를 찾는 함수
def find_closest_pair_distance(points):
    # 점들의 배열
    n = len(points)

    # 기본 케이스: 점의 개수가 3개 이하일 때 거리 비교해서 가장 짧은 거리 리턴
    if n <= 3:
        min_distance = float('inf') # 무한대 float형 선언
        for i in range(n):
            for j in range(i + 1, n):
                min_distance = min(min_distance, calculate_distance(points[i], points[j]))
        return min_distance

    # 점을 x좌표에 따라 정렬
    sorted_points = sorted(points, key=lambda x: x[0])

    # 왼쪽 영역과 오른쪽 영역에서 가장 가까운 거리를 재귀적으로 계산
    mid = n // 2
    left_distance = find_closest_pair_distance(sorted_points[:mid])
    right_distance = find_closest_pair_distance(sorted_points[mid:])

    # 두 부분의 최소 거리를 찾음
    min_distance = min(left_distance, right_distance)

    # 중간 영역에 걸쳐있는 점들 중 최소 거리를 찾음
    strip = []
    for i in range(n):
        if abs(sorted_points[i][0] - sorted_points[mid][0]) < min_distance:
            strip.append(sorted_points[i])

    # strip 영역에서 최소 거리를 찾음
    strip_min_distance = find_strip_closest_distance(strip, min_distance)

    # 최종적으로 최소 거리 반환
    return min(min_distance, strip_min_distance)

# 중간 영역에 걸쳐있는 점들 중 최소 거리를 찾는 함수
def find_strip_closest_distance(strip, min_distance):
    n = len(strip)
    min_strip_distance = min_distance  # 초기 최소 거리를 할당

    # strip 영역의 점들을 y 좌표를 기준으로 정렬
    strip.sort(key=lambda point: point[1])

    # strip 영역의 점들을 순회
    for i in range(n):
        for j in range(i + 1, n):
            # y 좌표의 차이가 현재 최소 거리보다 크거나 같으면 더 이상 비교할 필요가 없음
            if strip[j][1] - strip[i][1] >= min_strip_distance:
                break

            # 두 점 사이의 거리 계산
            distance = calculate_distance(strip[i], strip[j])

            # 더 작은 거리를 찾으면 최소 거리 업데이트
            min_strip_distance = min(min_strip_distance, distance)

    # strip 영역에서 찾은 최소 거리 반환
    return min_strip_distance

n = int(input().strip())  # 전체 점의 개수
points = []  # 각 점의 좌표를 저장할 리스트
for _ in range(n):
    # 각 점의 x, y 좌표를 입력받아 리스트에 추가
    x, y = map(int, input().strip().split(','))
    points.append((x, y))

# 가장 가까운 두 점 사이의 거리 계산
closest_distance = find_closest_pair_distance(points)

# 결과를 소수점 아래 6자리로 출력
print('{:.6f}'.format(closest_distance))

1. 두 점 사이의 거리를 계산해주는 함수

1
2
def calculate_distance(point1, point2):
    return ((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2) ** (1/2)

두 점의 좌표가 필요하므로, x축 좌표와 y축 좌표의 정보가 담긴 배열을 가진 두 점을 요소로 갖는 함수를 만들고, 수식을 그대로 넣어주면 된다. 제곱근을 구할 때, 라이브러리를 불러와도 되지만, 그냥 저렇게 2분의1을 곱해줘도 똑같은 값이 나온다. 대신 괄호 안에 꼭 넣어줘야 한다.

2. 가장 가까운 두 점 사이의 거리를 찾는 함수

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def find_closest_pair_distance(points):
    # 점들의 배열
    n = len(points)

    # 기본 케이스: 점의 개수가 3개 이하일 때 거리 비교해서 가장 짧은 거리 리턴
    if n <= 3:
        min_distance = float('inf') # 무한대 float형 선언
        for i in range(n):
            for j in range(i + 1, n):
                min_distance = min(min_distance, calculate_distance(points[i], points[j]))
        return min_distance

    # 점을 x좌표에 따라 정렬
    sorted_points = sorted(points, key=lambda x: x[0])

    # 왼쪽 영역과 오른쪽 영역에서 가장 가까운 거리를 재귀적으로 계산
    mid = n // 2
    left_distance = find_closest_pair_distance(sorted_points[:mid])
    right_distance = find_closest_pair_distance(sorted_points[mid:])

    # 두 부분의 최소 거리를 찾음
    min_distance = min(left_distance, right_distance)

    # 중간 영역에 걸쳐있는 점들 중 최소 거리를 찾음
    strip = []
    for i in range(n):
        if abs(sorted_points[i][0] - sorted_points[mid][0]) < min_distance:
            strip.append(sorted_points[i])

    # strip 영역에서 최소 거리를 찾음
    strip_min_distance = find_strip_closest_distance(strip, min_distance)

    # 최종적으로 최소 거리 반환
    return min(min_distance, strip_min_distance)

먼저 점들의 좌표 정보가 담긴 배열을 함수의 요소로 입력받는다.

  1. 점의 개수가 3개 이하일 때
    재귀 함수를 이용해야하기 때문에, 케이스를 둘로 나누어서 생각한다. 점의 개수가 3개 이하일 때는 그 점들 사이의 거리를 다 계산해서 가장 짧은 값을 리턴한다. 그 과정에서 min_distance의 초기값으로 float('inf')를 이용했는데, 이는 가장 min()함수를 이용할 것이기 때문에, 무한대의 값을 넣어준 것이다. 나중에 for문을 이용하여 비교할 것 이기 때문에..
  2. 그 이상일 때
    1. 먼저 점들을 x좌표에 따라 정렬한다.
      여기서 쓰이는 sorted(points, key=lambda x: x[0]) 이 정렬 방식을 알아둘 필요가 있다. 2차원 배열을 정렬할 때는 해당 소속 배열들의 특정 요소를 기준으로 정렬하는 이러한 법이 있었다.
    2. 왼쪽, 오른쪽 영역을 나누고, 가장 가까운 거리를 재귀적으로 계산한다.
      mid를 기준으로 배열을 나눈다. 쉽게 말해서, ‘:’이 왼쪽에 있으면 왼쪽 영역, ‘:’이 오른쪽에 있으면 오른쪽 영역이다. 참고로 배열이 홀수 개일 경우, 중앙값은 오른쪽 영역으로 가게된다. find_closest_pair_distance함수를 재귀적으로 호출하여 배열의 길이가 3 이하가 될 때까지 호출하여 리턴값을 받아낸다.
    3. 두 부분의 최소 거리를 비교한다.
      min()함수를 이용한다.
    4. 중간 영역에 걸쳐있는 점들 중 최소 거리를 찾는다.
      중간 영역에 걸쳐있는 점들의 거리를 체크한다. 두 영역 중 가장 짧은 거리를 뜻하는 min_distance보다 x좌표의 차이가 작은 점들만 strip[]에 넣어준다. 여기서 그냥 기준을 sorted_points[mid][0] 이 점으로 하면 된다. 항상 가운데, 혹은 가운데에서 한칸 오른쪽인 점이기 때문이다.
    5. find_strip_closest_distance()함수를 이용하여 최소거리를 찾아준다.
    6. 최종 리턴

3. 중간 영역에 걸쳐있는 점들 중 최소 거리를 찾는 함수

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def find_strip_closest_distance(strip, min_distance):
    n = len(strip)
    min_strip_distance = min_distance  # 초기 최소 거리를 할당

    # strip 영역의 점들을 y 좌표를 기준으로 정렬
    strip.sort(key=lambda point: point[1])

    # strip 영역의 점들을 순회
    for i in range(n):
        for j in range(i + 1, n):
            # y 좌표의 차이가 현재 최소 거리보다 크거나 같으면 더 이상 비교할 필요가 없음
            if strip[j][1] - strip[i][1] >= min_strip_distance:
                break

            # 두 점 사이의 거리 계산
            distance = calculate_distance(strip[i], strip[j])

            # 더 작은 거리를 찾으면 최소 거리 업데이트
            min_strip_distance = min(min_strip_distance, distance)

    # strip 영역에서 찾은 최소 거리 반환
    return min_strip_distance

for문 안에서 i(0부터 n-1까지의 수)번째 점들이 자신보다 위에 있는 점들과 비교해가면서 현재 최소 거리보다 크면 날리고, 작으면 거리를 계산하여 최소 거리 업데이트를 한다.

알게된 것 정리

제곱근 구하는 법

1
2
def calculate_distance(point1, point2):
    return ((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2) ** (1/2)

무한대 값을 표현하는 법 float('inf')

2차원 배열에서 어떠한 차원(x차원,y차원)을 기준으로 정렬하는 법
sorted(points, key=lambda x: x[0])

배열 나누기(mid를 기준으로 :이 왼쪽에 있으면 왼쪽 부분, 오른쪽에 있으면 오른쪽 부분! 길이가 홀수일 때는 가운데 요소가 오른쪽 부분으로 감)

1
2
3
mid = n // 2
left_distance = find_closest_pair_distance(sorted_points[:mid])
right_distance = find_closest_pair_distance(sorted_points[mid:])
This post is licensed under CC BY 4.0 by the author.