알고리즘 관련/BOJ

1967 트리의 지름 구하기

Andrew-Yun 2022. 2. 3. 22:09

https://www.acmicpc.net/problem/1967

 

1967번: 트리의 지름

파일의 첫 번째 줄은 노드의 개수 n(1 ≤ n ≤ 10,000)이다. 둘째 줄부터 n-1개의 줄에 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수는 간선이 연

www.acmicpc.net

모든 정점에서 BFS를 수행을 시도했더니 메모리 초과가 발생했다. 이유를 모르겠어서 정답을 참고했는데, 방식이 인상 깊어 기록으로 남기는게 좋을 것 같았다.

트리의 지름을 구하는 방법으로 루트 노드에서 가장 먼 정점을 구한다. 트리의 특성 상 절반으로 연산을 줄일 수 있기 때문이다. 이후 해당 정점에서 다시 가장 먼 정점의 거리를 구하면 트리의 지름이 된다.

import java.io.*;
import java.util.*;
/*
try 1: HashMap, ArrayList로 인접 리스트를 나타내어 각 정점마다 bfs를 수행 -> 메모리 초과
try 2: 트리의 지름 구하는 방법, 루트에서 가장 먼 정점을 고르고, 해당 정점에서 다시 가장 먼 정점을 고르면 지름이 된다.
*/
public class Main {
  public static int N, farIdx = 0;
  static ArrayList<int[]> graph[] = new ArrayList[10001];
  public static int bfs(int start){
    int[] dist = new int[N + 1];
    int maxDist = 0;
    Queue<Integer> q = new ArrayDeque<>();
    q.add(start);
    dist[start] = 0;
    while(!q.isEmpty()){
      int cur = q.poll();
      for(int[] edge : graph[cur]){
        int next = edge[0], cost = edge[1];
        if(dist[next] == 0 && next != start){
          dist[next] = dist[cur] + cost;
          maxDist = Math.max(maxDist, dist[next]);
          q.add(next);
        }
      }
    }
    int farDist = 0;
    for(int i = 1; i <= N; i++){
      if(farDist < dist[i]){
        farDist = dist[i];
        farIdx = i;
      }
    }
    return maxDist;
  }
  public static void main(String[] args) throws Exception {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    StringTokenizer st = new StringTokenizer(br.readLine());
    N = Integer.parseInt(st.nextToken());
    short from, to, dist;
    for(int i = 0; i < N + 1; i++)
      graph[i] = new ArrayList<>();
    for(int i = 0; i < N - 1; i++){
      st = new StringTokenizer(br.readLine());
      from = Short.parseShort(st.nextToken());
      to = Short.parseShort(st.nextToken());
      dist = Short.parseShort(st.nextToken());
      graph[from].add(new int[]{to, dist});
      graph[to].add(new int[]{from, dist});
    }
    if(N == 1){
      System.out.println(0);
      return;
    }
    bfs(1);
    int answer = bfs(farIdx);
    System.out.println(answer);
  }
}