线段树
视频讲解
🎥 视频讲解1
🎥 视频讲解2
例题1
建树
建树+修改
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+10;
int a[N];
ll sum[N<<2];
int n,q;
void update(int id){
sum[id]=sum[id<<1]+sum[id<<1|1];
}
void build(int id,int l,int r){
if(l==r){
sum[id]=a[l];
return;
}
int mid=(l+r)>>1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
update(id);
}
void modify(int id,int l,int r,int pos,int x){
if(l==r){
sum[id]+=x;
return;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(id<<1,l,mid,pos,x);
else modify(id<<1|1,mid+1,r,pos,x);
update(id);
}
ll query(int id,int l,int r,int ql,int qr){
if(l==ql&&r==qr) return sum[id];
int mid=(l+r)>>1;
if(qr<=mid) return query(id<<1,l,mid,ql,qr);
else if(ql>mid) return query(id<<1|1,mid+1,r,ql,qr);
else return query(id<<1,l,mid,ql,mid)+query(id<<1|1,mid+1,r,mid+1,qr);
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
for(int i=1;i<=q;i++){
int op,x,y;
cin>>op>>x>>y;
if(op==1) modify(1,1,n,x,y);
else cout<<query(1,1,n,x,y)<<"\n";
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main{
static final int N=200000+10;
static int[] a=new int[N];
static long[] sum=new long[N<<2];
static int n,q;
static void update(int id){
sum[id]=sum[id<<1]+sum[id<<1|1];
}
static void build(int id,int l,int r){
if(l==r){
sum[id]=a[l];
return;
}
int mid=(l+r)>>1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
update(id);
}
static void modify(int id,int l,int r,int pos,int x){
if(l==r){
sum[id]+=x;
return;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(id<<1,l,mid,pos,x);
else modify(id<<1|1,mid+1,r,pos,x);
update(id);
}
static long query(int id,int l,int r,int ql,int qr){
if(l==ql&&r==qr) return sum[id];
int mid=(l+r)>>1;
if(qr<=mid) return query(id<<1,l,mid,ql,qr);
else if(ql>mid) return query(id<<1|1,mid+1,r,ql,qr);
else return query(id<<1,l,mid,ql,mid)+query(id<<1|1,mid+1,r,mid+1,qr);
}
public static void main(String[] args)throws Exception{
BufferedReader br=new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st=new StringTokenizer(br.readLine());
n=Integer.parseInt(st.nextToken());
q=Integer.parseInt(st.nextToken());
st=new StringTokenizer(br.readLine());
for(int i=1;i<=n;i++) a[i]=Integer.parseInt(st.nextToken());
build(1,1,n);
StringBuilder sb=new StringBuilder();
for(int i=1;i<=q;i++){
st=new StringTokenizer(br.readLine());
int op=Integer.parseInt(st.nextToken());
int x=Integer.parseInt(st.nextToken());
int y=Integer.parseInt(st.nextToken());
if(op==1) modify(1,1,n,x,y);
else sb.append(query(1,1,n,x,y)).append('\n');
}
System.out.print(sb.toString());
}
}
import sys
input=sys.stdin.readline
N=200000+10
a=[0]*N
sumv=[0]*(N<<2)
def update(id):
sumv[id]=sumv[id<<1]+sumv[id<<1|1]
def build(id,l,r):
if l==r:
sumv[id]=a[l]
return
mid=(l+r)//2
build(id<<1,l,mid)
build(id<<1|1,mid+1,r)
update(id)
def modify(id,l,r,pos,x):
if l==r:
sumv[id]+=x
return
mid=(l+r)//2
if pos<=mid:
modify(id<<1,l,mid,pos,x)
else:
modify(id<<1|1,mid+1,r,pos,x)
update(id)
def query(id,l,r,ql,qr):
if l==ql and r==qr:
return sumv[id]
mid=(l+r)//2
if qr<=mid:
return query(id<<1,l,mid,ql,qr)
elif ql>mid:
return query(id<<1|1,mid+1,r,ql,qr)
else:
return query(id<<1,l,mid,ql,mid)+query(id<<1|1,mid+1,r,mid+1,qr)
n,q=map(int,input().split())
arr=list(map(int,input().split()))
for i in range(1,n+1):
a[i]=arr[i-1]
build(1,1,n)
for _ in range(q):
op,x,y=map(int,input().split())
if op==1:
modify(1,1,n,x,y)
else:
print(query(1,1,n,x,y))
例题2
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
struct node{
int minv,cnt;
}seg[N<<2];
int n,q,a[N];
node merge(const node &l,const node &r){
if(l.minv<r.minv) return {l.minv,l.cnt};
else if(l.minv>r.minv) return {r.minv,r.cnt};
else return {l.minv,l.cnt+r.cnt};
}
void update(int id){
seg[id]=merge(seg[id<<1],seg[id<<1|1]);
}
void build(int id,int l,int r){
if(l==r){
seg[id]={a[l],1};
return;
}
int mid=(l+r)>>1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
update(id);
}
void modify(int id,int l,int r,int pos,int x){
if(l==r){
seg[id]={x,1};
return;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(id<<1,l,mid,pos,x);
else modify(id<<1|1,mid+1,r,pos,x);
update(id);
}
node query(int id,int l,int r,int ql,int qr){
if(l==ql&&r==qr) return seg[id];
int mid=(l+r)>>1;
if(qr<=mid) return query(id<<1,l,mid,ql,qr);
else if(ql>mid) return query(id<<1|1,mid+1,r,ql,qr);
return merge(query(id<<1,l,mid,ql,mid),query(id<<1|1,mid+1,r,mid+1,qr));
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
for(int i=1;i<=q;i++){
int op;cin>>op;
if(op==1){
int x,d;cin>>x>>d;
modify(1,1,n,x,d);
}else{
int l,r;cin>>l>>r;
node ans=query(1,1,n,l,r);
cout<<ans.minv<<" "<<ans.cnt<<"\n";
}
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final int N = 200005;
static class Node {
int minv, cnt;
Node(int m, int c) { minv = m; cnt = c; }
}
static Node[] seg = new Node[N << 2];
static int[] a = new int[N];
static int n, q;
static Node merge(Node l, Node r) {
if (l == null) return r;
if (r == null) return l;
if (l.minv < r.minv) return new Node(l.minv, l.cnt);
else if (l.minv > r.minv) return new Node(r.minv, r.cnt);
else return new Node(l.minv, l.cnt + r.cnt);
}
static void update(int id) {
seg[id] = merge(seg[id << 1], seg[id << 1 | 1]);
}
static void build(int id, int l, int r) {
if (l == r) {
seg[id] = new Node(a[l], 1);
return;
}
int mid = (l + r) >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
update(id);
}
static void modify(int id, int l, int r, int pos, int x) {
if (l == r) {
seg[id] = new Node(x, 1);
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) modify(id << 1, l, mid, pos, x);
else modify(id << 1 | 1, mid + 1, r, pos, x);
update(id);
}
static Node query(int id, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return seg[id];
int mid = (l + r) >> 1;
Node res = null;
if (ql <= mid) res = merge(res, query(id << 1, l, mid, ql, qr));
if (qr > mid) res = merge(res, query(id << 1 | 1, mid + 1, r, ql, qr));
return res;
}
public static void main(String[] args) throws Exception {
// 使用 FastReader 防止读取瓶颈
FastReader fr = new FastReader(System.in);
n = fr.nextInt();
q = fr.nextInt();
for (int i = 1; i <= n; i++) a[i] = fr.nextInt();
build(1, 1, n);
PrintWriter out = new PrintWriter(System.out);
while (q-- > 0) {
int op = fr.nextInt();
if (op == 1) {
int x = fr.nextInt();
int d = fr.nextInt();
modify(1, 1, n, x, d);
} else {
int l = fr.nextInt();
int r = fr.nextInt();
Node ans = query(1, 1, n, l, r);
out.println(ans.minv + " " + ans.cnt);
}
}
out.flush();
}
static class FastReader {
StringTokenizer st;
BufferedReader br;
public FastReader(InputStream s) { br = new BufferedReader(new InputStreamReader(s)); }
public String next() throws Exception {
while (st == null || !st.hasMoreElements()) st = new StringTokenizer(br.readLine());
return st.nextToken();
}
public int nextInt() throws Exception { return Integer.parseInt(next()); }
}
}
import sys
input=sys.stdin.readline
N=200000+10
a=[0]*N
seg=[(0,0)]*(N<<2)
def merge(l,r):
if l[0]<r[0]:
return (l[0],l[1])
elif l[0]>r[0]:
return (r[0],r[1])
else:
return (l[0],l[1]+r[1])
def update(id):
seg[id]=merge(seg[id<<1],seg[id<<1|1])
def build(id,l,r):
if l==r:
seg[id]=(a[l],1)
return
mid=(l+r)//2
build(id<<1,l,mid)
build(id<<1|1,mid+1,r)
update(id)
def modify(id,l,r,pos,x):
if l==r:
seg[id]=(x,1)
return
mid=(l+r)//2
if pos<=mid:
modify(id<<1,l,mid,pos,x)
else:
modify(id<<1|1,mid+1,r,pos,x)
update(id)
def query(id,l,r,ql,qr):
if l==ql and r==qr:
return seg[id]
mid=(l+r)//2
if qr<=mid:
return query(id<<1,l,mid,ql,qr)
elif ql>mid:
return query(id<<1|1,mid+1,r,ql,qr)
else:
return merge(query(id<<1,l,mid,ql,mid),query(id<<1|1,mid+1,r,mid+1,qr))
n,q=map(int,input().split())
arr=list(map(int,input().split()))
for i in range(1,n+1):
a[i]=arr[i-1]
build(1,1,n)
for _ in range(q):
op,x,y=map(int,input().split())
if op==1:
modify(1,1,n,x,y)
else:
ans=query(1,1,n,x,y)
print(ans[0],end=" ")
print(ans[1])
例题3
代码实现
参考实现
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
struct node{
int d1,d2;
}seg[N<<2];
int a[N],n,q;
node merge(const node &l,const node &r){
node res={l.d1,l.d2};
if(r.d1>=res.d1) res={r.d1,l.d1};
else if(r.d1>res.d2) res.d2=r.d1;
if(r.d2>=res.d2) res={res.d1,r.d2};
return res;
}
void update(int id){
seg[id]=merge(seg[id<<1],seg[id<<1|1]);
}
void build(int id,int l,int r){
if(l==r){
seg[id]={a[l],-1};
return;
}
int mid=(l+r)>>1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
update(id);
}
void modify(int id,int l,int r,int pos,int x){
if(l==r){
seg[id]={x,-1};
return;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(id<<1,l,mid,pos,x);
else modify(id<<1|1,mid+1,r,pos,x);
update(id);
}
node query(int id,int l,int r,int ql,int qr){
if(l==ql&&r==qr) return seg[id];
int mid=(l+r)>>1;
if(qr<=mid) return query(id<<1,l,mid,ql,qr);
else if(ql>mid) return query(id<<1|1,mid+1,r,ql,qr);
return merge(query(id<<1,l,mid,ql,mid),query(id<<1|1,mid+1,r,mid+1,qr));
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>n>>q;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
for(int i=1;i<=q;i++){
int op,x,y;
cin>>op>>x>>y;
if(op==1) modify(1,1,n,x,y);
else{
node ans=query(1,1,n,x,y);
cout<<ans.d1<<" "<<ans.d2<<"\n";
}
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main{
static final int N=200000+10;
static class Node{
int d1,d2;
Node(int a,int b){d1=a;d2=b;}
}
static Node[] seg=new Node[N<<2];
static int[] a=new int[N];
static int n,q;
static Node merge(Node l,Node r){
Node res=new Node(l.d1,l.d2);
if(r.d1>=res.d1) res=new Node(r.d1,l.d1);
else if(r.d1>res.d2) res.d2=r.d1;
if(r.d2>=res.d2) res=new Node(res.d1,r.d2);
return res;
}
static void update(int id){
seg[id]=merge(seg[id<<1],seg[id<<1|1]);
}
static void build(int id,int l,int r){
if(l==r){
seg[id]=new Node(a[l],-1);
return;
}
int mid=(l+r)>>1;
build(id<<1,l,mid);
build(id<<1|1,mid+1,r);
update(id);
}
static void modify(int id,int l,int r,int pos,int x){
if(l==r){
seg[id]=new Node(x,-1);
return;
}
int mid=(l+r)>>1;
if(pos<=mid) modify(id<<1,l,mid,pos,x);
else modify(id<<1|1,mid+1,r,pos,x);
update(id);
}
static Node query(int id,int l,int r,int ql,int qr){
if(l==ql&&r==qr) return seg[id];
int mid=(l+r)>>1;
if(qr<=mid) return query(id<<1,l,mid,ql,qr);
else if(ql>mid) return query(id<<1|1,mid+1,r,ql,qr);
return merge(query(id<<1,l,mid,ql,mid),query(id<<1|1,mid+1,r,mid+1,qr));
}
public static void main(String[] args)throws Exception{
BufferedReader br=new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st=new StringTokenizer(br.readLine());
n=Integer.parseInt(st.nextToken());
q=Integer.parseInt(st.nextToken());
st=new StringTokenizer(br.readLine());
for(int i=1;i<=n;i++) a[i]=Integer.parseInt(st.nextToken());
build(1,1,n);
StringBuilder sb=new StringBuilder();
for(int i=1;i<=q;i++){
st=new StringTokenizer(br.readLine());
int op=Integer.parseInt(st.nextToken());
int x=Integer.parseInt(st.nextToken());
int y=Integer.parseInt(st.nextToken());
if(op==1) modify(1,1,n,x,y);
else{
Node ans=query(1,1,n,x,y);
sb.append(ans.d1).append(" ").append(ans.d2).append('\n');
}
}
System.out.print(sb.toString());
}
}
import sys
# 提高递归深度限制,防止深层线段树导致溢出
sys.setrecursionlimit(200000)
def solve():
# 使用快读提高效率
input = sys.stdin.read().split()
if not input:
return
idx = 0
n = int(input[idx]); idx += 1
q = int(input[idx]); idx += 1
a = [0] * (n + 1)
for i in range(1, n + 1):
a[i] = int(input[idx]); idx += 1
# 线段树节点数组,每个元素存储 [最大值, 次大值]
# 使用平铺列表比对象列表更快
tree_d1 = [0] * (4 * n + 1)
tree_d2 = [-1] * (4 * n + 1)
def merge(l1, l2, r1, r2):
"""
合并两个节点的两个最大值
l1, l2: 左子树的最大和次大
r1, r2: 右子树的最大和次大
"""
# 取四个数中最大的两个
vals = sorted([l1, l2, r1, r2], reverse=True)
return vals[0], vals[1]
def build(node_id, l, r):
if l == r:
tree_d1[node_id] = a[l]
tree_d2[node_id] = -1
return
mid = (l + r) // 2
build(node_id * 2, l, mid)
build(node_id * 2 + 1, mid + 1, r)
# 向上更新 (Push up)
tree_d1[node_id], tree_d2[node_id] = merge(
tree_d1[node_id * 2], tree_d2[node_id * 2],
tree_d1[node_id * 2 + 1], tree_d2[node_id * 2 + 1]
)
def modify(node_id, l, r, pos, val):
if l == r:
tree_d1[node_id] = val
tree_d2[node_id] = -1
return
mid = (l + r) // 2
if pos <= mid:
modify(node_id * 2, l, mid, pos, val)
else:
modify(node_id * 2 + 1, mid + 1, r, pos, val)
tree_d1[node_id], tree_d2[node_id] = merge(
tree_d1[node_id * 2], tree_d2[node_id * 2],
tree_d1[node_id * 2 + 1], tree_d2[node_id * 2 + 1]
)
def query(node_id, l, r, ql, qr):
if ql <= l and r <= qr:
return tree_d1[node_id], tree_d2[node_id]
mid = (l + r) // 2
if qr <= mid:
return query(node_id * 2, l, mid, ql, qr)
elif ql > mid:
return query(node_id * 2 + 1, mid + 1, r, ql, qr)
else:
l_max, l_sec = query(node_id * 2, l, mid, ql, mid)
r_max, r_sec = query(node_id * 2 + 1, mid + 1, r, mid + 1, qr)
return merge(l_max, l_sec, r_max, r_sec)
# 初始化建树
build(1, 1, n)
# 处理查询
results = []
for _ in range(q):
op = int(input[idx]); idx += 1
x = int(input[idx]); idx += 1
y = int(input[idx]); idx += 1
if op == 1:
modify(1, 1, n, x, y)
else:
res_max, res_sec = query(1, 1, n, x, y)
results.append(f"{res_max} {res_sec}")
# 批量输出提高效率
sys.stdout.write("\n".join(results) + "\n")
if __name__ == "__main__":
solve()