백준 1967(트리의 지름) - python

2023. 10. 5. 21:39알고리즘(python)/BFS,DFS

접근

트리에서 임의의 한점이 선택된다고 생각하자. 그렇다면 거기서부터 어떻게 트리의 최대 거리를 구할 수 있을까? 최대 거리는 가장 먼 2개의 점을 고르는 것과 같다. 임의의 한점에서 가장 먼 노드로 간다면, 가장 먼 2개의 점중 한점에 가게된다. 그리고 그 두개의 점중에서 가장 먼 노드를 구한다면 어떤 점으로 처음 출발하더라도 가장 긴 지름을 구할 수 있다.

가장 먼 노드를 구하는 방법은 bfs, dfs(다익스트라) 모두 가능하다. 전에 1167 트리의 지름으로 이미 풀어본 문제인데 복습할겸 다양한 방법으로 풀어봤다. 요즘 다익스트라가 너무 익숙해져서 힙을 이용한 다익스트라로도 풀었는데 heappop은 시간복잡도가 O(logn)이고 모든 노드가 한번씩 방문되기 때문에 queue를 쓰는것과 계산상 다를게 없다는걸 깨달았다. 또 스택을 이용한 dfs로 풀었는데 가장 빠르게 풀렸고 재귀를 이용한 풀이는 for에 비해 느린걸 확인할 수 있었다. 함수가 실행될때마다 stack에 변수들을 세팅해야하니 더 느릴수없다는 걸 알았지만 코딩에 익숙해지고 싶어 다양한 방식으로 풀었다.

 

나의 코드(재귀를 이용한 dfs)

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**9)
n = int(input())
G = [[] for _ in range(n+1)]
leaf = set()
maxv = 0
maxi = 0
for _ in range(n-1):
    s, f , v = map(int,input().split())
    G[s].append((f,v))
    G[f].append((s,v))

for i in range(1,n+1):
    if len(G[i]) == 1:
        leaf.add(i)

def dfs(start,sumv):
    visit[start] = True
    if start in leaf:
        global maxv
        global maxi
        if maxv < sumv:
                maxv = sumv
                maxi = start
    else:
        for f, v in G[start]:
            if not visit[f]:
                dfs(f,sumv+v)

leaf.discard(1)
visit = [False]*(n+1)
dfs(1,0)
leaf.discard(maxi)
visit = [False]*(n+1)
dfs(maxi,0)

print(maxv)

나의 코드(스택를 이용한 dfs)

import sys
input = sys.stdin.readline
n = int(input())

G = [[] for _ in range(n+1)]
for _ in range(n-1):
    s, f , v = map(int,input().split())
    G[s].append((v,f))
    G[f].append((v,s))

def dik(start):
    visit = [-1] * (n+1)
    visit[start] = 0
    H = []
    H.append(start)
    maxi = 0
    maxv = 0
    while H:
        v1= H.pop()
        for value, v2 in G[v1]:
            if visit[v2] == -1:
                visit[v2] = visit[v1]+value
                H.append(v2)
                if visit[v1]+value > maxv:
                    maxi = v2
                    maxv = visit[v1]+value
    return (maxi,maxv)

mi,mv = dik(1)
mi,mv= dik(mi)
print(mv)

나의 코드(heap을 이용한 bfs(다익스트라))

import sys
input = sys.stdin.readline
from heapq import heappush,heappop
n = int(input())
INF = int(10e9)
G = [[] for _ in range(n+1)]
for _ in range(n-1):
    s, f , v = map(int,input().split())
    G[s].append((v,f))
    G[f].append((v,s))

def dik(start):
    d = [INF] * (n+1)
    H = []
    d[start] = 0
    heappush(H,(0,start))
    maxi = 0
    maxv = 0
    while H:
        value, v1= heappop(H)
        if d[v1] < value:
            continue
        for value2, v2 in G[v1]:
            nvalue = value2 + value
            if d[v2] > nvalue:
                d[v2] = nvalue
                if maxv < nvalue:
                    maxi = v2
                    maxv = nvalue
                heappush(H,(nvalue,v2))
    return (maxi,maxv)

mi,mv = dik(1)
mi,mv= dik(mi)
print(mv)