백준 1865(웜홀) - python

2023. 10. 3. 23:35알고리즘(python)/최단경로

접근

기존 문제를 풀면서 다익스트라 알고리즘에 익숙해졌다. 익숙해진 다익스트라 알고리즘으로 접근했다.

해당 문제에서 내가 주의있게 본 부분은 웜홀은 시간이 줄어든다는 것, 즉 음수의 가중치가 있다는 점이었다.

다익스트라 알고리즘에 대해 생각해보면 전통적인 방법, heap을 사용하는 방법, heap을 사용하는 방법에서 재방문을 허용하는지, 하지않는지 3가지 정도 방법을 나뉜다.

나는 보통 heap을 사용하며 재방문을 허용하는 방식으로 변형해서 다익스트라 알고리즘을 사용한다. 가중치가 그리디한 방식으로 주어질 때 동일한 성능을 보이고 만약 가중치가 무작위로 주어지는 경우까지 답을 구할 수 있기 때문이다. 음의 가중치가 주어지는 경우에도 성능은 좋지 않지만 (처음 최단거리라고 생각하고 관련된 edge를 모두 힙에 넣었는데 나중에 그보다 작은 최단거리가 나오면 또 힙에 넣어서 연산해야해서 오버헤드가 크게 발생한다.) 풀 수 있다고 생각해서 다익스트라 알고리즘으로 접근했다.

하지만 주의해야할 점은 다익스트라 알고리즘으로 푸는 경우 음의 사이클이 발생하는 경우 무한루프에 빠지게 된다.

사이클이 발생하는 경우를 판별하기 위해 비용이 음수인 노드를 기억했다가 다시 그 노드를 다시 방문하면 빠져나오도록 했다.

그리고 만약 사이클이 있다면 문제에 시작 노드를 제시하지 않았기 때문에 출발하였던 위치로 돌아왔을때 시간이 줄었다고 생각할 수 있다.

 

나의 코드(실패)

import sys
input = sys.stdin.readline
from heapq import heappush,heappop
INF = 100000

def dik(start, n):
    d = [INF]*(n+1)
    checklist = [False]*(n+1)
    h = []
    heappush(h,(0,start))
    d[start] = 0
    while h:
        value1 , v1 = heappop(h)
        if d[v1] < value1:
            continue
        if d[v1] < 0:
            if v1 == start:
                return -1
            else:
                if not checklist[v1]:
                    checklist[v1] = True
                else:
                    return -1
        for value2, v2 in G[v1]:
            nextvalue = value1 + value2
            if d[v2] > nextvalue:
                d[v2] = nextvalue
                heappush(h,(nextvalue,v2))
    return 1


TC= int(input())

for _ in range(TC):
    n, m , w= map(int,input().split())
    G = [[] for _ in range(n+1)]

    for _ in range(m):
        s, e, time = map(int,input().split())
        G[s].append((time,e))
        G[e].append((time,s))

    for _ in range(w):
        s, e, time = map(int,input().split())
        G[s].append((-time,e))

    print("NO" if dik(1,n) == 1 else "YES")

 

해당 코드를 작성하고 틀리면서 틀린 원인을 생각해보았다.

1. 그래프가 모두 연결되어있지않을 수도 있다.

2. 사이클을 판단하기 위해서 checklist로 비용이 음수인 노드를 기억하게 하고 checklist[v1] 이 이미 방문 이력이 있으면 사이클이 있다고 판단하는 코드를 작성했는데 이것으로는 사이클을 판별할 수 없다.

 

재작성한 코드(다익스트라)

import sys
input = sys.stdin.readline
from heapq import heappush,heappop
INF = 100000

def dik(start, n):
    d = [INF]*(n+1)
    checklist = [False]*(n+1)
    h = []
    check = 0
    checknode = 0
    heappush(h,(0,start))
    d[start] = 0
    while h:
        value1 , v1 = heappop(h)
        if d[v1] < value1:
            continue
        if d[v1] < 0:
            if v1 == start:
                return -1
            else:
                if check == 0:
                    check = 1
                    checknode = v1
                else:
                    if checknode == v1:
                        return 1
        for value2, v2 in G[v1]:
            nextvalue = value1 + value2
            if d[v2] > nextvalue and not checklist[v2]:
                if nextvalue < 0 :
                    checklist[v2] = True
                    d[v2] = nextvalue
                    heappush(h,(nextvalue,v2))
                else:
                    d[v2] = nextvalue
                    heappush(h,(nextvalue,v2))
    return 1


TC= int(input())

for _ in range(TC):
    n, m , w= map(int,input().split())
    G = [[] for _ in range(n+1)]

    for _ in range(m):
        s, e, time = map(int,input().split())
        G[s].append((time,e))
        G[e].append((time,s))

    for _ in range(w):
        s, e, time = map(int,input().split())
        G[s].append((-time,e))

    ch = 0
    for i in range(1,n+1):
        d= dik(i,n)
        if d == -1:
            ch = 1
            break
    print("NO" if ch == 0 else "YES")

무식하게 1을 해결하기 위해 모든 노드를 기준으로 다익스트라를 돌리고 사이클을 찾는 것이 아니라 제자리로 더 적은 시간으로 돌아올 수 있는지만 판별하는 코드를 작성하였고 정답 판정을 받을 수 있었다. 하지만 매우 효율이 안좋은 코드인 만큼 다른 방법을 탐색해야겠다는 생각을 했다.

 

벨만포드 알고리즘을 이용하여 사이클을 찾는것이 좋겠다는 생각을 했다.

재작성한 코드(벨만포드)

import sys
input = sys.stdin.readline
INF = 100000

def ford(start):
    d = [INF]*(n+1)
    d[start] = 0
    for i in range(n):
        for s,e,time in edge:
            if d[s] + time < d[e]:
                d[e] = d[s] + time
                if i == n-1:
                    return -1
                print(d)
    return 1


TC= int(input())

for _ in range(TC):
    n, m , w= map(int,input().split())
    edge = set()
    checknode = [False] * (n+1)
    for _ in range(m):
        s, e, time = map(int,input().split())
        edge.add((s,e,time))
        edge.add((e,s,time))

    for _ in range(w):
        s, e, time = map(int,input().split())
        edge.add((s,e,-time))


    print("NO" if ford(1) == 1 else "YES")

일반적인 벨만포드와는 다르게 d[s] != INF 조건은 빼야한다. 아까 말했던 것처럼 모든 노드가 연결되어있다는 보장이 없어 한번의 벨만포드알고리즘을 이용해 답을 구하고자 하면 start로 둔 1 과 연결되지 않은 노드에 대해서도 음의 사이클이 있는지 확인해야한다.

문제를 여러번 풀다가 느낀것은 해당 문제는 결국 음의 사이클이 있는지만 확인하면 되는 문제라는 것이다. 경로를 최소한으로 업데이트하고 n번동안 경로를 기억하면서 최적화 할 수 있다고 생각했다.

 

최적화한 코드

import sys
input = sys.stdin.readline

TC= int(input())

for _ in range(TC):
    n, m , w= map(int,input().split())
    G = [[] for _ in range(n+1)]

    for _ in range(m):
        s, e, time = map(int,input().split())
        G[s].append((e,time))
        G[e].append((s,time))

    for _ in range(w):
        s, e, time = map(int,input().split())
        G[s].append((e,-time))

    cost = [0] * (n+1)
    min_cost = [0] * (n+1)
    *arr, =  range(1,n+1)
    visit = set()
    min_index =[]
    rc = 1
    for _ in range(n):
        temp = tuple(arr)
        if temp in visit:
            break
        visit.add(temp)
        while arr:
            index = arr.pop()
            for st, value in G[index]:
                if cost[index] + value < min_cost[st]:
                    min_cost[st] = cost[index] + value
                    min_index.append(st)
        if min_index:
            for z in min_index:
                cost[z] = min_cost[z]
        else:
            break
        arr, min_index = min_index, arr

    if arr:
        rc = 0
    print("NO" if rc == 1 else "YES")