[자료구조] 세그먼트 트리 | Segment Tree
세그먼트 트리(Segment Tree)란?
세그먼트 트리는 어떤 구간에 대한 정보(특정 구간에서의 최대/최솟값, 평균, 합 등)를
각각의 노드가 저장하고 있다는 특성을 이용해, O(logN)의 시간복잡도로 구할 수 있게 해주는 자료 구조입니다.
구현의 편의를 위해, 완전 이진트리 형태로 구현하였습니다.
먼저, 전체적인 구조는 아래와 같습니다.
주어진 노드(BaseNodes)의 개수 $N$에 대해 $N < X$를 만족하는 가장 가까운$2^N$값 X에 대해,
세그먼트 트리의 루트노드는 구간 [0,X]까지의 BaseNodes의 세그먼트 값(합)을 의미합니다.
또한, 이 BaseNode의 자식들은 각각 구간[0,X//2),[X//2,X)의 합을 저장하고 있습니다.
이러한 세그먼트 트리의 특성을 잘 이용하면, 우리는 단순히 선형탐색했을 때의 시간복잡도인 O(N)보다
월등히 빠른, O(logN)의 시간 복잡도를 얻을 수 있습니다.
그럼, 이제부터 세그먼트 트리의 선언, 갱신, 그리고 원하는 구간의 세그먼트 값을 구하는 Query의 구현이 어떻게 이루어지는지 알아보겠습니다.
1. 세그먼트 트리의 선언
완전 이진트리 구조로 세그먼트 트리를 선언하므로, 세그먼트 트리의 노드 개수는
$N_1$ = 1이고, r = 2 인 등비수열의 $N_1 + ... + N_{log_2X}$, 즉 $2^{log_2X + 1}$와 같습니다.
따라서, 코드에서의 세그먼트 트리의 각 정보들(깊이 / 노드의 수 / BaseNode의 개수)은 다음과 같이 정의됩니다.
segmentDepth = ceil(log2(n)) + 1 #세그먼트 트리의 총 깊이
segmentTree = [1] * pow(2,segmentDepth) #세그먼트 트리의 각 노드의 값을 저장하는 배열
segmentBases = pow(2,segmentDepth - 1) #세그먼트 트리의 BaseNode의 개수(X)
2. 세그먼트 트리의 갱신
세그먼트 트리의 갱신을 위해서는 자신이 갱신하고자 하는 X번째 값, 즉 BaseNode에서의 X번 노드가 세그먼트 트리에서 실제로 어느 인덱스에 위치해있는지 알 필요가 있습니다.
완전 이진 트리이므로, 이는 간단하게 구할 수 있습니다.
세그먼트 트리에서, X번 BaseNode(주어진 배열에서의 X번째 값)의 인덱스 $Index_X$
$Index_X = TreeSize - BaseNode + X$
위와 같이 $Index_X$를 구하고 나면,
해당 $Index_X$번 노드를 범위에 포함하는 노드, 즉 부모 노드에 대해서 트리를 거슬러 올라가며 값을
갱신해주면 됩니다.
def updateValue(targetNode : int, value : int): #갱신하면서 올라가기
targetNode = getNthNodes(targetNode)
#print(targetNode,value)
diff = segmentTreeNodes[targetNode] - value #
while targetNode:
segmentTreeNodes[targetNode] -= diff
targetNode //= 2 #위로 올라가면서 부모값 갱신
3. 세그먼트 트리의 쿼리 처리
세그먼트 트리에서, 원하는 구간의 합을 구하는 함수 GetAreaSum의 구현은 다음과 같습니다.
def getAreaSum(node : int, start : int, end : int, left : int, right : int) -> int:
1) 두 구간이 일치하는 경우
# 해당 노드의 세그먼트 값을 반환 및 함수 종료
areaSum = 0
2) left < mid 인경우
#현재 노드의 왼쪽 노드, 즉 구간 [start,mid)에 구하고자 하는 구간 [left,min(right,mid))이 포함 되어 있음
#따라서, 상술한 구간의 합을 구하기 위해, getAreaSum(node * 2, start, mid, left, min(mid,right))를 호출, 그 결과값을 areaSum에 더함
3) mid < right 인 경우
# 현재 노드의 오른쪽 노드, 즉 구간 [mid,end)에 구하고자 하는 구간 [max(left,mid),right)이 포함 되어 있음
# 따라서, 상술한 구간의 합을 구하기 위해, getAreaSum(node * 2 + 1, mid, end, max(left,mid), right)를 호출, 그 결과값을 areaSum에 더함
return areaSum
위에서 설명한 내용을 코드로 구현하면 다음과 같습니다.
import sys
from math import log2,ceil
input = sys.stdin.readline
""" 세그먼트 트리 출력하기
def displaySegmentTree():
index = 1
dis = 1
while index < len(segmentTreeNodes):
sepa = segBaseNodes // dis - 1
print(" " * sepa,end = "")
print(*segmentTreeNodes[index: index + dis],sep=" " * (sepa * 2 + 1))
index += dis
dis *= 2
"""
""" Query 함수 정의
def getAreaSum(node : int, start : int, end : int, left : int, right : int) -> int:
1) 두 구간이 일치하는 경우
# 해당 노드의 세그먼트 값을 반환 및 함수 종료
areaSum = 0
2) left < mid 인경우
#현재 노드의 왼쪽 노드, 즉 구간 [start,mid)에 구하고자 하는 구간 [left,min(right,mid))이 포함 되어 있음
#따라서, 상술한 구간의 합을 구하기 위해, getAreaSum(node * 2, start, mid, left, min(mid,right))를 호출, 그 결과값을 areaSum에 더함
3) mid < right 인 경우
# 현재 노드의 오른쪽 노드, 즉 구간 [mid,end)에 구하고자 하는 구간 [max(left,mid),right)이 포함 되어 있음
# 따라서, 상술한 구간의 합을 구하기 위해, getAreaSum(node * 2 + 1, mid, end, max(left,mid), right)를 호출, 그 결과값을 areaSum에 더함
return areaSum
"""
n,m,k = map(int,input().split())
segmentDepth = ceil(log2(n)) + 1
segmentTreeNodes = [0] * (pow(2,segmentDepth))
segBaseNodes = pow(2,segmentDepth-1)
def updateValue(targetNode : int, value : int): #갱신하면서 올라가기
targetNode = len(segmentTreeNodes) - segBaseNodes + targetNode
diff = segmentTreeNodes[targetNode] - value #
while targetNode:
segmentTreeNodes[targetNode] -= diff
targetNode //= 2 #위로 올라가면서 부모값 갱신
def getSegmentSum(wantLeft : int, wantRight : int) -> int:
return getAreaSum(1,0,segBaseNodes,wantLeft,wantRight + 1)
def getAreaSum(node : int, start: int, end: int, wantLeft: int, wantRight: int) -> int:
#원하는 영역의 합 구하기
#print(node,start,end,wantLeft,wantRight)
mid = (start + end) // 2
if (start,end) == (wantLeft,wantRight): #영역이 동일하면 해당 세그먼트 값 리턴하기
return segmentTreeNodes[node]
areaSum = 0
if wantLeft < mid:
areaSum += getAreaSum(node * 2, start, mid, wantLeft, min(wantRight, mid))
if wantRight > mid:
areaSum += getAreaSum(node * 2 + 1, mid, end, max(mid,wantLeft), wantRight)
return areaSum
numberList = [int(input()) for _ in range(n)]
for i in range(len(numberList)):
updateValue(i,numberList[i])
for i in range(m + k):
cmd,a,b = map(int,input().split())
if cmd == 1: #노드 업데이트
updateValue(a-1, b)
else: #구간 [a,b]를 구하는 쿼리
print(getSegmentSum(a-1,b-1))