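# 權值線段樹(值域線段樹):與普通線段樹並無其他區別,只不過每個節點儲存的是其值域內所有值出現的總次數罷了。
# Value-domain ("weight") segment tree: structurally an ordinary segment tree, except that
# each node stores the total number of occurrences of the values inside its value range.
# 理解圖:(原文此處附有示意圖,未隨程式碼保存。/ A diagram originally appeared here and was not preserved.)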
import sys
def input():
    """Return the next raw line from ``sys.stdin``.

    Shadows the builtin ``input``. Unlike the builtin, the trailing newline is
    kept, and an empty string (not EOFError) is returned at end of input.
    """
    return sys.stdin.readline()
class Tree:
    """Value-domain ("weight") segment tree over an integer value range.

    Leaf ``v`` holds how many times value ``v`` has been inserted; every
    internal node holds the sum of its two children, i.e. the number of stored
    values inside that node's value range.

    Nodes use 1-based heap indexing: node ``root`` has children ``root << 1``
    and ``root << 1 | 1``.  Every method takes ``(root, l, r)``; top-level calls
    pass ``root=1`` together with the full value range, e.g. ``(1, 1, N)``.
    """

    def __init__(self, N):
        """Allocate ``N`` zeroed node slots.

        Args:
            N: number of node slots (not the size of the value range);
                ``4 * (r - l + 1)`` slots suffice for a value range [l, r].
        """
        self.cnt = [0] * N

    def update(self, root, l, r, x, cnt):
        """Add ``cnt`` occurrences of value ``x`` (a negative ``cnt`` removes them).

        Args:
            root: index of the current node (1 at the top-level call).
            l, r: inclusive value range covered by ``root``.
            x: the value whose count changes.
            cnt: amount added to the count of ``x``.

        Raises:
            ValueError: if ``x`` lies outside [l, r].
        """
        # Without this check an out-of-range value keeps descending toward one
        # side and is silently counted at the boundary leaf (l or r).  Only the
        # top-level call can fail: an in-range x stays in range in the child.
        if not l <= x <= r:
            raise ValueError(f"value {x} is outside the range [{l}, {r}]")
        if l == r:
            self.cnt[root] += cnt
            return
        mid = (l + r) >> 1
        if x <= mid:
            self.update(root << 1, l, mid, x, cnt)
        else:
            self.update(root << 1 | 1, mid + 1, r, x, cnt)
        # Re-aggregate: an internal node stores the total of its children.
        self.cnt[root] = self.cnt[root << 1] + self.cnt[root << 1 | 1]

    def query(self, root, l, r, ql, qr):
        """Return how many stored values fall in the value range [ql, qr].

        Args:
            root: index of the current node (1 at the top-level call).
            l, r: inclusive value range covered by ``root``.
            ql, qr: inclusive value range being counted.
        """
        if ql > r or qr < l:
            # Node range and query range are disjoint.
            return 0
        if ql <= l and qr >= r:
            # Node range lies entirely inside the query range.
            return self.cnt[root]
        mid = (l + r) >> 1
        return (self.query(root << 1, l, mid, ql, qr)
                + self.query(root << 1 | 1, mid + 1, r, ql, qr))

    def find_kth(self, root, l, r, k):
        """Return the k-th smallest stored value (1-based, duplicates counted).

        Args:
            root: index of the current node (1 at the top-level call).
            l, r: inclusive value range covered by ``root``.
            k: rank to look up.

        Raises:
            ValueError: if ``k`` is not in ``1 .. self.cnt[root]``; without this
                check an invalid rank would silently yield ``l`` or ``r``.
        """
        if not 1 <= k <= self.cnt[root]:
            raise ValueError(
                f"k={k} is out of range; the node holds {self.cnt[root]} values")
        # A valid k stays valid in whichever child is chosen, so one check at
        # the start is enough for the whole descent.
        while l < r:
            mid = (l + r) >> 1
            left = root << 1
            if k <= self.cnt[left]:
                root, r = left, mid
            else:
                # Skip every value stored in the left half.
                k -= self.cnt[left]
                root, l = left | 1, mid + 1
        return l
# Demo over the value range [1, N].  When the value range is too large,
# first compress (discretize) the values onto a small contiguous range.
N = 10
lst = [1, 2, 3, 4, 5, 5, 5, 5, 5]
mytree = Tree(N * 4)  # 4 * N node slots cover a tree over [1, N]

# Insert every element once.
for x in lst:
    mytree.update(1, 1, N, x, 1)

# Occurrence count of each single value 1..N, space-separated on one line.
for i in range(1, N + 1):
    print(mytree.query(1, 1, N, i, i), end=' ')
print()
# 1 1 1 1 5 0 0 0 0 0

# The 4th smallest stored value (k is 1-based).
print(mytree.find_kth(1, 1, N, 4))
# 4