Problem
알고리즘 교수님은 학생들에게 병합 정렬을 이용해 오름차순으로 정렬하는 과제를 내려고 한다.
정렬 된 결과만으로는 실제로 병합 정렬을 적용했는지 알 수 없기 때문에 다음과 같은 제약을 주었다.
N개의 정렬 대상을 가진 리스트 L을 분할할 때 L[0:N//2], L[N//2:N]으로 분할 한다.
병합 과정에서 다음처럼 왼쪽 마지막 원소가 오른쪽 마지막 원소보다 큰 경우의 수를 출력한다.
정렬이 끝난 리스트 L에서 L[N//2] 원소를 출력한다.
알고리즘 교수님의 조건에 따라 병합 정렬을 수행하는 프로그램을 만드시오.
[입력]
첫 줄에 테스트케이스의 수 T가 주어진다. 1<=T<=50
다음 줄부터 테스트 케이스의 별로 정수의 개수 N이 주어지고, 다음 줄에 N개의 정수 ai가 주어진다.
5<=N<=1,000,000, 0 <= ai <= 1,000,000
[출력]
각 줄마다 "#T" (T는 테스트 케이스 번호)를 출력한 뒤, N//2 번째 원소와 오른쪽 원소가 먼저 복사되는 경우의 수를 출력한다.
Solving
아이디어
- (재귀) 배열을 계속 반으로 나눈다. 0~mid 전까지 : left배열 / mid ~ last 까지 : right 배열
- 배열의 길이가 1일 때, 배열을 리턴
- 배열의 길이가 2 이상일 때, left 배열과 right 배열을 정렬하면서 합친다.
- 각 부분배열의 첫번째 인덱스부터 비교한다. 더 작은 값을 배열에 넣는다. (만약 값이 같으면 왼쪽 배열의 값을 넣는다. 왼쪽 배열의 원소가 오른쪽 배열의 원소보다 클 때만 오른쪽 배열의 원소를 넣고, 그렇게 해야 마지막 원소 비교할 때 올바르게 카운팅을 할 수 있기 때문이다.)
- 두 부분배열 중 하나가 탐색이 끝나면 나머지 한 배열의 나머지 값들을 순서대로 배열에 넣는다.
- 합쳐진 배열을 리턴한다.
코드
def merge_sort(arr):
global cnt # 왼쪽 마지막 원소가 오른쪽 마지막 원소보다 큰 경우 cnt += 1
# 원소가 1개이면 그냥 그대로 return
# 이유 : 이미 정렬된 배열임
if len(arr) == 1:
return arr
# 배열의 원소가 2개 이상이면 배열을 두개로 나눠서 각각 정렬된 배열을 return 받는다.
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
l_len = len(left)
r_len = len(right)
# idx : 원본 배열의 index
# l_idx : 왼쪽 배열의 index
# r_idx : 오른쪽 배열의 index
idx = l_idx = r_idx = 0
# 정렬된 왼쪽 배열과 오른쪽 배열을
# 첫번째 index부터 마지막 index까지 비교하면서 작은 값 부터 가져온다.
# 두 배열 중 하나라도 탐색이 끝나면 반복문 종료
while l_idx < l_len and r_idx < r_len:
if left[l_idx] <= right[r_idx]:
arr[idx] = left[l_idx]
l_idx += 1
else:
arr[idx] = right[r_idx]
r_idx += 1
idx += 1
# 왼쪽 배열의 탐색이 끝났다면 (즉, 오른쪽 배열에 값이 남아있으면)
if l_idx == l_len:
# 가져오지 않은 값들을 다 가져온다.
for i in range(r_idx, r_len):
arr[idx] = right[i]
idx += 1
# 오른쪽 배열의 탐색이 끝났다면 (즉, 왼쪽 배열에 값이 남아있으면)
elif r_idx == r_len:
# 왼쪽 배열의 마지막 원소가 오른쪽 배열의 마지막 원소보다 크기때문에
# 왼쪽 배열에 값이 남아있는 것이다. -> 카운팅
cnt += 1
# 가져오지 않은 값들을 다 가져온다.
for i in range(l_idx, l_len):
arr[idx] = left[i]
idx += 1
# 병합 정렬된 배열을 리턴
# 마지막 return인 경우 : 최종적으로 병합 정렬된 배열을 리턴
# 그렇지 않은 경우 : left or right 변수로 병합 정렬된 중간 결과물을 리턴
return arr
for tc in range(1, int(input()) + 1):
N = int(input())
in_arr = list(map(int, input().split()))
cnt = 0
print(f'#{tc} {merge_sort(in_arr)[N//2]} {cnt}')