[two pointer] (3-5) list의 구간을 두개의 포인터로, 시간복잡도 줄이기

2024. 6. 17. 21:53개발의 흔적/코딩테스트

문제:  수들의 합

N개의 수로 된 수열 A[1], A[2], ..., A[N] 이 있다. 이 수열의 i번째 수부터 j번째 수까지의 합 A[i]+A[i+1]+...+A[j-1]+A[j]가 M이 되는 경우의 수를 구하는 프로그램을 작성하시오.

<입력설명>
첫째 줄에 N(1≤N≤10,000), M(1≤M≤300,000,000)이 주어진다. 다음 줄에는 A[1], A[2], ..., A[N]이 공백으로 분리되어 주어진다. 각각의 A[x]는 30,000을 넘지 않는 자연수이다.

<출력설명>
첫째 줄에 경우의 수를 출력한다.

입력예제

8 3

1 2 1 3 1 1 1 2

출력예제

5

 

위 문제는 1≤N≤10,000 의 제한을 가지고있다.

 

그렇기 때문에 combinations를 이용해 모든 조합 쌍을 찾는 방법은 애초에 불가능하고, 가능하다해도 비효율 적이다.

(10000C1 + 10000C2 + ..... + 10000C9999 + 100000C100000 을 계산하는것 자체에서 반드시 시간초과가 날 것)

 

위 문제의 case들을 머리속에 조금만 그려보면, 결국 구간끊기를 이용해야 한다는 생각에 다다른다.

어떠한 숫자 List의 시작구간과 마지막 구간을 기준으로 고려해 보겠다는 것이다.

예제의 경우를 통해 예를 들어보면,

len(list) == 8(즉, N)일 것이고, list[0:0] 부터 list[8:8]까지의 case가 생긴다.

모든 구간을 다 탐색한다고 일차적으로 생각하면, 

for i in range(N):

    for j in range(N):

list[start : end]의 시간복잡도는 O(start_point의_경우의_수) * O(end_point_경우의_수) == O(N) * O(N) == O(N^2)이다.

 

그러나 이 N이 10,000일 경우의 시간복잡도는 100,000,000, 즉 1억이므로 코드를 작성하기에 조금 애매하다. 까딱해서 중간에 추가 연산코드가 추가되면 시간초과의 border line에 걸릴 수도 있기 때문이지.

 

여기에 더해 위 계산 역시 결국 '모든 경우의 수를 모두 탐색해보겠다' 이기 때문에 비효율적이다.

 

그래서  두개의 pointer, 즉 start_point와 end_point 구간을 M을 기준으로 효율적이게 움직여야 하겠다 목적이 나온다.

그렇다, 결국 two pointer이다.

 

그래서 나온 첫 생각:

 

[i(start_point) : j(end_point)]를 [0:0] 일때부터 [N-1, N-1]까지 꼬물꼬물 움직일건데,

[i : j] 구간의 합이 M보다 작을때 / M 이상일 때로 상황을 나누고, 

M보다 작으면 j를 키우고, M 이상일때는 i를 키우자. 단 M과 같아지면 당연히 cnt에 찍어주고

 

위 생각을 슈도코드 형식으로 나타내보면 아래와 같다.

# list[i:j] /  마지막_index == N-1
cnt = 0
i, j = 0
while (i, j가 모두 N-1이 아니라면):
    while (list[i:j] 구간의 합이 M보다 작을때):
        #[i:j]까지의 합이 M 이상이 될때까지
        j += 1
    while (list[i:j] 구간의 합이 M 이상일때):
        if 딱 M이라면:
            cnt += 1
        #[i:j] 합이 M이나 혹은 그 이하가 될때까지
        i += 1
        #M보다 작아지면 위의 while문에서 j가 커지는 작업 반복
print(cnt)

 

2가지 상황은 특정상황(M미만 / M이상)일 경우 계속 되어야하므로 각각 while문.

그리고 그 두상황이 반복해서 일어나야 하므로 탐색이 끝나기 전까지(i,j가 N-1이 될때까지) 큰 while문 안에 넣어준다.

 

이렇게 되면 i와 j는 각각 0에서 시작해서 한번씩 N-1 까지 도달한다. 

list[start : end]의 시간복잡도는 O(start_point의_경우의_수)  +  O(end_point_경우의_수)

== O(N) + O(N) == O(2N) == O(N)

 

마지막 최종코드: 

 

import sys
input = sys.stdin.readline

N, M = map(int, input().split())
num_list = list(map(int, input().split()))

cnt = 0
sum = 0
index_i = 0
index_j = 0
while True:
    while sum < M and index_j < N:
        sum += num_list[index_j]
        if index_j != N - 1:
            index_j += 1

    while sum >= M:
        if sum == M:
            cnt += 1
            if index_j == N - 1:
                break
        sum -= num_list[index_i]
        index_i += 1

    if sum == M and index_j == N - 1:
        break
print(cnt)