Notice
Recent Posts
Recent Comments
Link
«   2026/04   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30
Tags
more
Archives
Today
Total
관리 메뉴

Elevation

트리 (4) - 센트로이드 본문

ps/tree

트리 (4) - 센트로이드

aste999 2026. 3. 24. 00:05

본래 무게중심이라는 뜻의 단어인 센트로이드(centroid)는, 트리에서 하나의 정점을 제거했을 때 나눠지는 서브 트리의 정점의 개수가 원래 정점의 개수의 절반 이하가 되도록 하는 정점으로 정의된다. 따라서 센트로이드를 찾으면 트리를 균형 있게 분할하는 것이 가능하다. 몇 가지 센트로이드에 대한 직관적이고도 중요한 성질들이 있다.

 

  • 트리에는 센트로이드가 항상 존재한다.
  • 센트로이드는 1개 혹은 2개 존재한다.
  • 센트로이드가 2개 존재하는 경우, 두 센트로이드는 서로 인접해 있다.

 

트리에서 분할 정복을 수행하려는 경우, 고른 root에 따라 최악에는 트리의 깊이가 정점 수 $N$까지 늘어날 수 있어 시간이 오래 걸릴 우려가 있다. 이럴 때 전체 트리의 센트로이드 $u$를 찾은 뒤, $u$에 대해 처리를 완료한 후 분할된 각 서브트리들의 센트로이드를 찾는 과정을 반복하면 트리의 깊이가 깊어지는 것을 방지할 수 있을 것이다. 이러한 알고리즘을 센트로이드 분할이라 한다.

 

 

센트로이드 분할

센트로이드 분할의 구현은 상당히 직관적이다. 시작 정점을 하나 잡아 서브트리의 정점의 개수(size)를 dfs로 먼저 구해 주고, 한번 더 dfs를 돌면서 $u$와 모든 인접한 정점들 $v$에 대해 $size[v]$가 총 정점 수의 절반을 넘지 않는다면 $u$를 센트로이드로 선정하면 된다. visited 배열을 관리하면서 센트로이드에 대해 처리한 후에는 나눠진 각 서브트리들에 대해 다시 분할을 수행한다.

size = [0]*N
visited = [False]*N
def get_size(u, p):
    size[u] = 1
    for v in adj[u]:
        if not visited[v] and v != p:
            get_size(v, u)
            size[u] += size[v]

def get_centroid(u, p, n):
    for v in adj[u]:
        if not visited[v] and v != p and size[v] > n//2:
            return get_centroid(v, u, n)
    return u

def decompose(root):
    get_size(root, -1)
    cent = get_centroid(root, -1, size[root])

    # ... 

    visited[cent] = True
    for v in adj[cent]:
        if not visited[v]: decompose(v)

 

 

다만 $N=10^5$ 정도의 제한에서는 재귀 기반 dfs가 메모리/시간 측면에서 상당한 부담이기 때문에, 다음과 같이 스택을 이용해서 size를 구하고 센트로이드를 찾는 것이 보다 바람직하다. 방문 순서를 기록해 놓고 그 역순으로 size값을 합쳐 주면 된다.

visited = [False]*(N+1)
size = [0]*(N+1)
def get_centroid(root):
    stack = [(root, -1)]
    visit_order = []
    while stack:
        u, p = stack.pop()
        visit_order.append((u, p))
        size[u] = 1
        for v in adj[u]:
            if v != p and not visited[v]: stack.append((v, u))
    
    for u, p in reversed(visit_order):
        if p != -1: size[p] += size[u]
    
    node = root
    p = -1
    while 1:
        val = True
        for v in adj[node]:
            if v == p or visited[v]: continue
            if size[v] > size[root]//2:
                val = False
                p = node
                node = v
                break
        if val: return node

 

 

센트로이드 트리

트리에 업데이트 쿼리가 들어오는 경우에는 센트로이드 분할의 결과를 미리 저장해 놓고 업데이트를 처리해야 할 것이다. 따라서 이러한 경우에는 전체 트리의 센트로이드를 루트 정점으로 하여, 분할된 서브 트리들의 센트로이드 각각을 원본 트리의 센트로이드와 연결해 주어 새로운 트리를 만들 수 있다. 이를 센트로이드 트리라고 부른다. 센트로이드 트리는 다음과 같은 중요한 성질을 갖는다.

 

  • 트리의 총 정점의 개수는 유지되며, 트리의 깊이는 최대 $\log N$이다.
  • 임의의 두 정점 $u, v$에 대해 원본 트리에서 $u, v$를 잇는 경로는 반드시 센트로이드 트리에서 $LCA(u, v)$를 지난다.

두 번째 성질은 자연스러우면서도 매우 흥미롭다. $LCA(u, v)$는 $u$와 $v$를 서로 다른 서브트리로 갈라 놓은 정점이므로 $u, v$를 잇는 경로는 반드시 해당 정점을 지날 수밖에 없다. 한편 이 성질은 센트로이드 트리가 원본 트리와는 전혀 다름에도 불구하고 원본 트리에서 분할 정복을 수행하듯이 센트로이드 트리에서도 경로에 관해 분할 정복을 수행하는 것이 가능함을 의미한다. 일반적으로 두 정점의 경로에 관한 쿼리를 두 정점의 공통 조상에서 처리하는 것과 동일하게, 센트로이드 트리 내에서도 두 정점의 공통 조상에서 정보를 합쳐서 처리할 수 있는 것이다.

 

 

트리와 쿼리 5(BOJ 13514)를 보자. naive하게 떠올릴 수 있는 해결 방법은 정점별로 $dists[u]$=($u$를 루트로 하는 서브트리 내에서, 흰색 정점들과의 거리 리스트)를 관리하는 것이다. 다만 원본 트리의 구조를 그대로 이용하면 업데이트와 최솟값 출력에 $O(N)$ 이상이 소요될 수밖에 없다. 따라서 해당 아이디어를 그대로 가져오되, 센트로이드 트리 내에서 쿼리를 수행한다.

 

  • 1번 쿼리: 대상 정점 $u$부터 센트로이드 트리를 타고 올라가면서, 모든 조상 $p$에 대해 $dists[p]$에서 $dist(u,p)$를 추가 or 삭제한다.
  • 2번 쿼리: 대상 정점 $u$의 모든 조상 $p$에 대해 $res=\min(res, \min(dists[p])+dist(u,p))$

 

추가, 삭제, 최솟값 확인을 빠르게 수행하기 위해 우선순위 큐를 사용하면, 결과적으로 두 쿼리 모두 $O(\log ^2 N)$만에 처리가 가능한 것을 확인할 수 있다.

cent_parent = [0]*(N+1)
def decompose(root, p):
    cent = get_centroid(root)
    cent_parent[cent] = p
    visited[cent] = True
    for v in adj[cent]:
        if not visited[v]: decompose(v, cent)
decompose(1, 0)

dist_list = [[] for _ in range(N+1)]
remove_list = [[] for _ in range(N+1)]
color = [False]*(N+1)
for _ in range(M):
    q1, u = map(int, input().split())
    if q1 == 1:
        color[u] = not color[u]
        p = u
        while p != 0:
            if color[u]: heapq.heappush(dist_list[node], dist(u, p))
            else: heapq.heappush(remove_list[node], dist(u, p))
            p = cent_parent[p]
    else:
        res = float('inf')
        p = u
        while p != 0:
            while remove_list[p] and remove_list[p][0]==dist_list[p][0]:
                heapq.heappop(dist_list[p])
                heapq.heappop(remove_list[p])
            if dist_list[p]: res = min(res, dist_list[p][0] + dist(u, p))
            p = cent_parent[p]
        if res == float('inf'): print(-1)
        else: print(res)