-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSegmentTree.py
More file actions
202 lines (184 loc) · 5.21 KB
/
SegmentTree.py
File metadata and controls
202 lines (184 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
'''
来自于wls的线段树的模板
'''
'''
part one
1. 建树
2. 更新单点值
3. 区间查询
更新与查询时间复杂度为O(log n)
'''
class SegmentTree(object):
"""docstring for SegmentTree"""
def __init__(self, nums):
super(SegmentTree, self).__init__()
self.nums = nums
self.n = len(self.nums)
self.t = [0] * (4*n+1)
def BuildTree(self,k,l,r):
# k表示在线段树中的下标, l,r 分别表示区间的端点
if l == r:
self.t[k] = self.nums[l]
return
m = (l+r) >> 1
self.BuildTree(k+k,l,m)
self.BuildTree(k+k+1,m+1,r)
self.t[k] = self.t[k+k] + self.t[k+k+1]
def UpdateValue(self,k,l,r,idx,v):
# k表示线段树位置下标,l,r分别表示区间端点,idx表示原数组中的位置,v表示需要更新出来的值
self.t[k] += v
if (l == r):
return
m = (l+r) >> 1
if x <= m:
# 向左边递归
self.UpdateValue(k+k,l,m,idx,v)
else:
self.UpdateValue(k+k+1,m+1,r,idx,v)
def QueryRange(self,k,l,r,s,t):
# k表示在线段树位置的下标,l,r表示区间端点, s,t表示为需要问询的区间的端点
if (l == s) and (r == t):
return self.t[k]
m = (l+r) >> 1
if t <= m:
return self.QueryRange(k+k,l,m,s,t)
else:
if s > m:
return self.QueryRange(k+k+1,m+1,r,s,t)
else:
return self.QueryRange(k+k,l,m,s,m) + self.QueryRange(k+k+1,m+1,r,m+1,t)
'''
part two
支持区间修改以及区间查询
log n的区间修改,我们需要使用一个新的变量v用于进行标记
log n的区间查询,考虑一个区间的查询,分为两个部分上面的v的部分以及下面整体的部分
'''
class SegmentTree2(object):
"""docstring for SegmentTree2"""
def __init__(self, nums):
super(SegmentTree2, self).__init__()
self.nums = nums
self.n = len(self.nums) - 1
self.t = [0] * (4*self.n+1)
self.v = [0] * (4*self.n+1) #self.v[k]表示当前k位置的区间每个值都需要更新的值
def BuildTree(self,k,l,r):
# the same as the last class
if (l == r):
self.t[k] = self.nums[l]
return
m = (l+r) // 2
self.BuildTree(k+k,l,m)
self.BuildTree(k+k+1,m+1,r)
self.t[k] = self.t[k+k] + self.t[k+k+1]
def UpdateRange(self,k,l,r,x,y,z):
# 在x到y的区间更新为z
if (l == x) and (r == y):
self.v[k] += z
return
self.t[k] += (y-x+1) * z
m = (l+r) >> 1
if y <= m:
self.UpdateRange(k+k,l,m,x,y,z)
else:
if (x > m):
self.UpdateRange(k+k+1,m+1,r,x,y,z)
else:
self.UpdateRange(k+k,l,m,x,m,z)
self.UpdateRange(k+k+1,m+1,r,m+1,y,z)
def QueryRange(self,k,l,r,s,t,p):
# 问询s到t的区间的值
# p表示为走到当前的点v的值累计和为多少
p += self.v[k]
if (l == s) and (r == t):
return self.t[k] + (r-l+1) * p
m = (l+r) >> 1
if t <= m:
return self.QueryRange(k+k,l,m,s,t,p)
else:
if s > m:
return self.QueryRange(k+k+1,m+1,r,s,t,p)
else:
return self.QueryRange(k+k,l,m,s,m,p) + self.QueryRange(k+k+1,m+1,r,m+1,t,p)
'''
part three
标记下放
原本我们分为两个部分:上面路径的v以及下面路径的f
标记希望我们只求一个部分
我们考虑将上面的路径将v值往下穿是的前面的v都为0
父亲节点v清空,将两边的区间的v添加父亲节点的v
但是需要考虑父亲节点的f值:解决方案:递归结束的时候反过来更新f值
'''
class SegmentTree3(object):
"""docstring for SegmentTree3"""
def __init__(self, nums):
super(SegmentTree3, self).__init__()
self.nums = nums
self.n = len(self.nums)
self.t = [0] * (4*self.n+1)
self.v = [0] * (4*self.n+1)
def BuildTree(self,k,l,r):
if l == r:
self.t[k] = self.nums[l]
return
m = (l+r) >> 1
self.BuildTree(k+k,l,m)
self.BuildTree(k+k+1,m+1,r)
self.t[k] = self.t[k+k] + self.t[k+k+1]
def UpdateRange(self,k,l,r,x,y,z):
# 修改区间x到y的值为z
if (l == x) and (r == y):
self.v[k] += z
return
# 标记下传
if self.v[k] != 0:
self.v[k+k] += self.v[k]
self.v[k+k+1] += self.v[k]
self.v[k] = 0
m = (l+r) >> 1
if y <= m:
self.UpdateRange(k+k,l,m,x,y,z)
else:
if x > m:
self.UpdateRange(k+k+1,m+1,r,x,y,z)
else:
self.UpdateRange(k+k,l,m,x,m,z)
self.UpdateRange(k+k+1,m+1,r,m+1,y,z)
# 值更新
self.t[k] = self.t[k+k] + self.v[k+k] * (m-l+1) + self.t[k+k+1] + self.v[k+k+1] * (r-m)
def QueryRange(self,k,l,r,x,y):
if (l == x) and (r == y):
return self.t[k] + self.v[k] * (r-l+1)
# 标记下传
if self.v[k] != 0:
self.v[k+k] += self.v[k]
self.v[k+k+1] += self.v[k]
self.v[k] = 0
m = (l+r) >> 1
ans = 0
if y <= m:
ans = self.QueryRange(k+k,l,m,x,y)
else:
if x > m:
ans = self.QueryRange(k+k+1,m+1,r,x,y)
else:
ans = self.QueryRange(k+k,l,m,x,m) + self.QueryRange(k+k+1,m+1,r,m+1,y)
# 值更新
self.t[k] = self.t[k+k] + self.v[k+k] * (m-l+1) + self.t[k+k+1] + self.v[k+k+1] * (r-m)
return ans
if __name__ == "__main__":
n,m = list(map(int,input().split()))
nums = [0] + list(map(int,input().split()))
st = SegmentTree3(nums)
n = len(nums) - 1
st.BuildTree(1,1,n)
qs = []
for i in range(m):
cur = list(map(int,input().split()))
qs.append(cur)
for cur in qs:
if cur[0] == 1:
# 更新操作
st.UpdateRange(1,1,n,cur[1],cur[2],cur[3])
elif cur[0] == 2:
# 查询操作
print(st.QueryRange(1,1,n,cur[1],cur[2]))