0%

算法基础02

  • 链表与邻接表
  • 栈与队列
  • kmp
  • Trie
  • 并查集
  • 哈希表
  • 树状数组
  • 线段树

链表

数组模拟单链表

邻接表(存储数、图)

head -> null

head -> o0 -> o1 -> o2 -> null

int e[N] 表示值

int ne[N] 表示next指针(指向next的下标)为空:ne[n-1] = -1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import java.util.Scanner;

//单链表
public class chain_list_1 {
static int N = 100010;
//head表示头结点的下标
static int head;
//e[i]表示节点i的值
static int[] e = new int[N];
//ne[i]表示节点i的next指针是多少
static int[] ne = new int[N];
//idx存储当前已经用到了哪个节点
static int idx;

//初始化
static void init() {
head = -1;
idx = 0;
}

//将x插到头结点
static void add_to_head(int x) {
e[idx] = x;
ne[idx] = head;
head = idx;
idx++;
}

//将x插到下标是k的点后面
static void add(int k, int x) {
e[idx] = x;
ne[idx] = ne[k];
ne[k] = idx;
idx++;
}

//将下标是k的点后面的点删掉
static void remove(int k) {
ne[k] = ne[ne[k]];
}

public static void main(String[] args) {
init();
Scanner sc = new Scanner(System.in);
int m = sc.nextInt();
while (m-- != 0) {
String op = sc.next();
if (op.charAt(0) == 'H') {
int x = sc.nextInt();
add_to_head(x);
} else if (op.charAt(0) == 'D') {
int k = sc.nextInt();
if (k == 0) head = ne[head];
else remove(k - 1);
} else {
int k = sc.nextInt();
int x = sc.nextInt();
add(k - 1, x);
}
}
for (int i = head; i != -1; i = ne[i]) {
System.out.printf("%d ", e[i]);
}
}
}

数组模拟双链表

优化某些问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
public class chain_list_2 {
static int N = 100010;
//e[i]表示节点i的值
static int[] e = new int[N];
//l[i]表示节点i左边指向点的下标
static int[] l = new int[N];
//r[i]表示节点i右边指向点的下标
static int[] r = new int[N];

//idx存储当前已经用到了哪个节点
static int idx;

void init() {
//0表示左端点,l表示右端点
r[0] = 1;
l[1] = 0;
idx = 2;
}

//在下标是k的点的右边,插入x
void add(int k, int x) {
e[idx] = x;
r[idx] = r[k];
l[idx] = k;
l[r[k]] = idx;
r[k] = idx;//注意两句不能写反
idx++;
}

//删除第k个点
void remove(int k) {
r[l[k]] = r[k];
l[r[k]] = l[k];
}

}

模拟栈

模拟队列

单调栈

给定一个序列,求这个序列中每个数左边(右边)离其最近的比它小的数的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import java.util.*;

public class Main{
static int m;
static int N = 100010;
static int[] stack = new int[N];
static int tt = 0;
public static void main(String[] args){
Scanner sc = new Scanner(System.in);
m = sc.nextInt();
while(m -- > 0){
int x;
x = sc.nextInt();
while(tt != 0 && stack[tt] >= x) tt --;
if(tt != 0) System.out.print(stack[tt] + " ");
else System.out.print("-1 ");
stack[++ tt] = x;
}
}
}

单调队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

public class Main {
static int N = 1000010;
static int[] a = new int[N];

static int[] q = new int[N];


public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String[] str = br.readLine().split(" ");
int n = Integer.parseInt(str[0]);
int k = Integer.parseInt(str[1]);
str = br.readLine().split(" ");
for (int i = 0; i < n; i++) a[i] = Integer.parseInt(str[i]);
int hh = 0, tt = -1;
for (int i = 0; i < n; i++) {
if (hh <= tt && i - k + 1 > q[hh]) hh++;

while (hh <= tt && a[q[tt]] >= a[i]) tt--;//单增队列

q[++tt] = i;

if (i >= k - 1) System.out.print(a[q[hh]] + " ");
}
System.out.println();
hh = 0;
tt = -1;
for (int i = 0; i < n; i++) {
if (hh <= tt && i - k + 1 > q[hh]) hh++;
while (hh <= tt && a[q[tt]] <= a[i]) tt--;//单减队列
q[++tt] = i;
if (i >= k - 1) System.out.print(a[q[hh]] + " ");
}
System.out.println();
}
}

kmp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import java.io.*;
public class Main{
static int N = 100010;
static int ne[] = new int[N];
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
Integer n = Integer.parseInt(br.readLine());
String s1 = " " + br.readLine();
Integer m = Integer.parseInt(br.readLine());
String s2 = " " + br.readLine();
char[] a1 = s1.toCharArray();
char[] a2 = s2.toCharArray();
/**
* ne[]:存储一个字符串以每个位置为结尾的‘可匹配最长前后缀’的长度。
* 构建ne[]数组:
* 1,初始化ne[1] = 0,i从2开始。
* 2,若匹配,s[i]=s[j+1]说明1~j+1是i的可匹配最长后缀,ne[i] = ++j;
* 3,若不匹配,则从j的最长前缀位置+1的位置继续与i比较
* (因为i-1和j拥有相同的最长前后缀,我们拿j的前缀去对齐i-1的后缀),
* 即令j = ne[j],继续比较j+1与i,若匹配转->>2
* 4,若一直得不到匹配j最终会降到0,也就是i的‘可匹配最长前后缀’的长度
* 要从零开始重新计算
*/
for(int i = 2,j = 0;i <= n ;i++) {
while(j!=0&&a1[i]!=a1[j+1]) j = ne[j];
if(a1[i]==a1[j+1]) j++;
ne[i] = j;
}
/**
* 匹配两个字符串:
* 1,从i=1的位置开始逐个匹配,利用ne[]数组减少比较次数
* 2,若i与j+1的位置不匹配(已知1~j匹配i-j~i-1),
* j跳回ne[j]继续比较(因为1~j匹配i-j~i-1,所以1~ne[j]也能匹配到i-ne[j]~i-1)
* 3,若匹配则j++,直到j==n能确定匹配成功
* 4,成功后依然j = ne[j],就是把这次成功当成失败,继续匹配下一个位置
*/
for(int i = 1,j = 0; i <= m;i++) {
while(j!=0&&a2[i]!=a1[j+1]) j = ne[j];
if(a2[i]==a1[j+1]) j++;
if(j==n) {
j = ne[j];
bw.write(i-n+" ");
}
}
/**
* 时间复杂度:
* 因为:j最多加m次,再加之前j每次都会减少且最少减一,j>0
* 所以:while循环最多执行m次,若大于m次,j<0矛盾
* 最终答案:O(2m)
*/
bw.flush();
}
}

KMP在求循环节中的运用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import java.io.*;

public class Main{
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static int N = 1000010;
static int[] ne = new int[N];

public static void main(String[] args) throws IOException {
int cnt = 1;
while(true){
int n = Integer.parseInt(br.readLine());
if(n == 0)break;
System.out.println("Test case #" + cnt++);
char[] s = (" " + br.readLine()).toCharArray();
for(int i = 2, j = 0; i <= n; i++){
while(j != 0 && s[i] != s[j + 1])j = ne[j];
if(s[i] == s[j + 1])j++;
ne[i] = j;
}
//通过ne数组对循环节进行检验
for(int i = 1; i <= n; i++){
int t = i - ne[i];//循环节长度
if(i % t == 0 && i / t > 1){
System.out.println(i + " " + i / t);
}
}
System.out.println();
}
}
}

Trie树

用来高效存储和查找字符串集合的数据结构

集合的数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import java.util.Scanner;

public class trie {
static int N = 100010;

static int[][] son = new int[N][26];//存储节点下标(第N个节点是否存在?的儿子,不存在为0,存在存储儿子的idx位置)
static int[] cnt = new int[N];//标记以其为结尾的字符串个数便于查询
static int idx = 0;//保证每个Trie树的新分支(新节点)存储空间不同不冲突

static void insert(String str) {
int p = 0;
for (int i = 0; i < str.length(); i++) {
int u = str.charAt(i) - 'a';
if (son[p][u] == 0) son[p][u] = ++idx;
p = son[p][u];
}
cnt[p]++;
}

static int query(String str) {
int p = 0;
for (int i = 0; i < str.length(); i++) {
int u = str.charAt(i) - 'a';
if (son[p][u] == 0) return 0;
p = son[p][u];
}
return cnt[p];
}

public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
for (int i = 0; i < n; i++) {
String op = sc.next();
if (op.equals("I")) insert(sc.next());
else System.out.println(query(sc.next()));
}
}
}

好栗子:

异或数列

在给定的 N 个整数 A1,A2……AN 中选出两个进行 xor(异或)运算,得到的结果最大是多少?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import java.util.Scanner;

public class test {
static int N = 31 * 100010;
static int[][] son = new int[N][2];

static int idx = 0;

static int[] a = new int[N];

static void insert(int x) {
int p = 0;
for (int i = 30; i >= 0; i--) {
int k = x >> i & 1;//位运算取第k位
if (son[p][k] == 0) son[p][k] = ++idx;
p = son[p][k];
}
}

static int query(int x) {
int p = 0;
int res = 0;
for (int i = 30; i >= 0; i--) {
int k = x >> i & 1;
int kp = k == 0 ? 1:0;
if (son[p][kp] != 0) {
p = son[p][kp];
res = res * 2 + kp;
} else {
p = son[p][k];
res = res * 2 + k;
}
}
return res;
}

public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int max = 0;
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
insert(a[i]);
int res = query(a[i]);
max = Math.max(max, res ^ a[i]);
}
System.out.println(max);
}
}

并查集

  • 将两个元素合并

  • 询问两个元素是否在一个集合当中

并查集可在近乎O(1)的复杂度情况下完成这两个操作

基本原理:每一个集合用树来表示,树根的编号就是树的编号。每个节点存储它的父节点

  • 问题一:如何判断树根 if(p[x] == x)

  • 问题二:如何求x的集合编号 while(p[x] != x) x = p[x]

  • 问题三:如何合并两个集合 px是x的集合编号 py是y的集合编号 p[x] = y

  • 问题四:统计各集合内元素个数 合并集合时 cnt[y] += cnt[x] 以y为合并集合根节点,返回cnt[find(a)]

优化:路径压缩

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include<iostream>
using namespace std;

const int N = 100010;

int p[N];

int find(int x){//返回x的祖宗节点+路径压缩
if(p[x]!=x)p[x] = find(p[x]);
return p[x];
}

int main(){
int n,m;
scanf("%d%d", &n,&m);
for(int i=0;i<n;i++)p[i] = i;

while(m--){
char op[2];
int a,b;
scanf("%s%d%d", op, &a, &b);
if(op[0] == 'M')p[find(a)] = find(b);
else{
if(find(a) == find(b)) puts("Yes");
else puts("No");
}
}
return 0;
}

java模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import java.util.Scanner;

public class union_set {
static int N = 100010;

static int[] p = new int[N];

static int find(int x){//返回x的祖宗节点+路径压缩
if(p[x]!=x)p[x] = find(p[x]);
return p[x];
}

public static void main(String[] args){
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
for(int i=0;i<n;i++)p[i] = i;

while(m-- != 0){
String op = sc.next();
int a = sc.nextInt();
int b = sc.nextInt();
if(op.equals("M"))p[find(a)] = find(b);
else {
if(find(a) == find(b))System.out.println("Yes");
else System.out.println("No");
}
}
}

}

堆->完全二叉树

小根堆:每个节点的值小于等于其左右儿子节点的值

堆的存储:

x的左儿子:2x

x的右儿子:2x+1

down(x){} O(logn)
up(x){} O(logn)

堆支持的操作

  • 插入一个数 heap[ ++ size] = x; up[size];
  • 求集合当中的最小值 heap[1];
  • 删除最小值 heap[1] = heap[size]; size--;down(1);
  • 删除任意一个元素 heap[k] = heap[size];size--;down(k);up(k);
  • 修改任意一个元素 heap[k] = x; down(k);up(k);
1
2
3
4
5
6
7
8
9
void down(int u){
int t = u;
if(u * 2 <= size && h[u * 2] < h[t]) t = u * 2;
if(u * 2 + 1 <= size && h[u * 2 + 1] < h[t]) t = u * 2 + 1;
if(u!=t){
swap(h[u], h[t]);
down(t);
}
}
1
2
3
4
5
6
void up(int u){
while(u/2 && h[u/2] > h[u]) {
swap(h[u], h[u/2]);
u /= 2;
}
}

堆排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import java.util.Scanner;

public class heap_sort {
static int N = 100010;
static int[] h = new int[N];
static int size = 0;

static void down(int u) {
int k = u;
if (u * 2 <= size && h[u * 2] < h[k]) k = u * 2;
if (u * 2 + 1 <= size && h[u * 2 + 1] < h[k]) k = u * 2 + 1;
if (k != u) {
int t = h[k];
h[k] = h[u];
h[u] = t;
down(k);
}
}

static void up(int u) {
while (u / 2 >= 1 && h[u / 2] > h[u]) {
int t = h[u];
h[u] = h[u / 2];
h[u / 2] = t;
u /= 2;
}
}

public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();

for (int i = 1; i <= n; i++) {//习惯堆下标从1开始
h[i] = sc.nextInt();
}
size = n;
for (int i = n / 2; i > 0; i--) down(i);//初始化堆 复杂度O(n)

while (m-- != 0) {
System.out.print(h[1] + " ");
h[1] = h[size];
size--;
down(1);
}
}
}

哈希表

  • 存储结构:开放寻址法、拉链法(链地址法)
  • 字符串哈希方式

开放寻址法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import java.util.Arrays;
import java.util.Scanner;

public class hash {
static int N = 200003;// 质数 一般开数据范围的2~3倍, 这样大概率就没有冲突了
static final int C = 0x3f3f3f3f;
static int[] h = new int[N];

static int find(int x){//存在返回对应下标,不存在返回其应该存储的位置
int k = (x % N + N) % N;
while(h[k] != C && h[k] != x){
k++;
if(k == N) k = 0;
}
return k;
}

public static void main(String[] args){
Arrays.fill(h, C);

Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
while (n-- != 0) {
String s = sc.next();
int x = sc.nextInt();
int k = find(x);
if(s.equals("I"))h[k] = x;
else{
if(h[k] != C)System.out.println("Yes");
else System.out.println("No");
}
}
}
}

拉链法(链地址法)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import java.util.Scanner;

//链地址法(拉链法)
public class hash {
static int N = 100003;//选择质数
static int[] h = new int[N+1];//hash table
static int[] e = new int[N+1];
static int[] ne = new int[N+1];
static int idx = 1;

static void insert(int x) {
int k = (x % N + N) % N;//保证大于等于0
e[idx] = x;
ne[idx] = h[k];
h[k] = idx++;
}

static boolean find(int x) {
int k = (x % N + N) % N;
for(int i = h[k]; i != 0; i = ne[i]){
if(e[i] == x)return true;
}
return false;
}

public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
while (n-- != 0) {
String s = sc.next();
if (s.equals("I")) {
insert(sc.nextInt());
} else {
if (find(sc.nextInt())) {
System.out.println("Yes");
} else {
System.out.println("No");
}
}
}
}
}

字符串前缀哈希方式

(字符串哈希) O(n)+O(m)

全称字符串前缀哈希法,把字符串变成一个p进制数字(哈希值),实现不同的字符串映射到不同的数字。

对形如 X1X2X3⋯Xn−1Xn 的字符串,采用字符的ascii 码乘上 P 的次方来计算哈希值。

映射公式 (X1×Pn−1+X2×Pn−2+⋯+Xn−1×P1+Xn×P0)modQ

注意点:

  1. 任意字符不可以映射成0,否则会出现不同的字符串都映射成0的情况,比如A,AA,AAA皆为0
  2. 冲突问题:通过巧妙设置P (131 或 13331) , Q (2^64)的值,一般可以理解为不产生冲突。

问题是比较不同区间的子串是否相同,就转化为对应的哈希值是否相同。

求一个字符串的哈希值就相当于求前缀和,求一个字符串的子串哈希值就相当于求部分和。

前缀和公式 h[i]=h[i-1]×P+s[i] i∈[1,n] h为前缀和数组,s为字符串数组

区间和公式 h[l,r]=h[r]−h[l−1]×P[r−l+1]

区间和公式的理解: ABCDE 与 ABC 的前三个字符值是一样,只差两位,

乘上 P2 把 ABC 变为 ABC00,再用 ABCDE - ABC00 得到 DE 的哈希值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import java.util.Scanner;

public class string_hash {
static int N = 100010;
static int P = 131; //经验值P = 131, 13331

//前缀哈希求完后需要进行模2^64来防止相同的冲突
static long[] p = new long[N];
static long[] h = new long[N];

static long get(int l, int r){//获得l~r区间的哈希值
return h[r] - h[l-1] * p[r-l+1];
}

public static void main(String[] args){
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
String s = sc.next();
p[0] = 1;
for(int i = 1; i <= n; i++){
p[i] = p[i-1] * P;
h[i] = h[i-1] * P + s.charAt(i-1);
}

while(m-- != 0){
int l1 = sc.nextInt();
int r1 = sc.nextInt();
int l2 = sc.nextInt();
int r2 = sc.nextInt();

if(get(l1,r1) == get(l2,r2))System.out.println("Yes");
else System.out.println("No");
}
}
}

树状数组

可解决问题:

  • 快速求前缀和 O(logn)
  • 快速求前缀最大值
  • 修改 O(logn) 查询 O(logn)

本质上解决一类问题的在线做法:单点修改、区间查询

*数组 求前缀和 O(n) 修改 O(1)*

*前缀和数组 查询O(1) 修改O(n)*

c[x] = a[x - lowbit(x) + 1, x] = a(x - lowbit(x), x]

c[x]长度:lowbit(x)

区间描述:以x结尾的,长度是2^k(k为x的最后一位1的位置)的区间

父节点找子节点:(适用于求前缀和)

子节点找父节点(适用于修改操作 子节点变动引起父节点变动)

模板:

1
2
3
static int lowbit(int x){
return x & -x;
}

只能加上一个数,不能完全变成一个数,但可以用x + (-x) + y转化

1
2
3
static void add(int x, int k){ //修改t[x] (并修改受其影响的父结点)
for(int i = x; i <= n; i += lowbit(i))t[i] += k;
}
1
2
3
4
5
6
7
static int sum(int x){ //查询t[x] (由父结点找到所有子结点) sum[1~x]
int res = 0;
for(int i = x; i != 0; i -= lowbit(i)){
res += t[i];
}
return res;
}

初始化(单点加)

1
for(int i = 1; i <= n; i++)add(i, a[i]);

查询a[l ~ r] (求区间和)

1
sum(r) - sum(l - 1)

树状数组下标必须以1开始

例题:楼兰图腾(单点修改,区间查询)

https://www.acwing.com/problem/content/243/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import java.io.*;
import java.util.Arrays;

public class Main{
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static int N = 200010;
static int[] a = new int[N];
static int[] t = new int[N];
static int n;
static int[] lower = new int[N];
static int[] higher = new int[N];

//求二进制最后一位1
static int lowbit(int x){
return x & -x;
}
//修改操作,找到所有唯一的父节点,相当于递归,往上找父节点
static void add(int x, int k){
for(int i = x; i <= n; i += lowbit(i)){
t[i] += k;
}
}
//查询操作,求前缀和操作,根据画出来的树状图,即所有每次减去一个最后一位1
static int sum(int x){
int res = 0;
for(int i = x; i != 0; i -= lowbit(i)){
res += t[i];
}
return res;
}

public static void main(String[] args) throws IOException {
n = Integer.parseInt(br.readLine());
String[] str = br.readLine().split(" ");
for(int i = 1; i <= n; i++){
a[i] = Integer.parseInt(str[i - 1]);
}

for(int i = 1; i <= n; i++){
int y = a[i];//y表示坐标
lower[i] = sum(y - 1);//然后这里求y-1就是因为左边是比他小的数,所以上面up不加的
higher[i] = sum(n) - sum(y);//右边所有比他大的数
add(y, 1);//然后在y这个坐标加上1
}

Arrays.fill(t, 0);//因为需要进行两边操作,所以需要进行清空树状数组

//会爆int,因为每一个点左右两边最坏可能都有n个数
//那就是n方个数,然后执行n次,就是n3方,爆int
long resA = 0, resV = 0;
for(int i = n; i >= 1; i--){//然后将数组翻转过来,重新操作一遍,原理一样
int y = a[i];
resA += (long)lower[i] * sum(y - 1);
resV += (long)higher[i] * (sum(n) - sum(y));
add(y, 1);
}
System.out.println(resV + " " + resA);
}
}

差分逆运算也可以转变成前缀和进行树状数组操作

例题:一个简单的整数问题(区间查询,单点修改)

https://www.acwing.com/problem/content/248/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import java.io.*;

public class Main{
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static int N = 100010;
static int[] a = new int[N];
static long[] b = new long[N];
static long[] tr = new long[N];
static int n;

static int lowbit(int x){
return x & -x;
}

static void add(int x, int k){
for(int i = x; i <= n; i += lowbit(i))tr[i] += k;
}

static long sum(int x){
long res = 0;
for(int i = x; i > 0; i -= lowbit(i))res += tr[i];
return res;
}

public static void main(String[] args) throws IOException {
String[] str = br.readLine().split(" ");
n = Integer.parseInt(str[0]);
int m = Integer.parseInt(str[1]);
str = br.readLine().split(" ");
for(int i = 1; i <= n; i++){
a[i] = Integer.parseInt(str[i - 1]);
}
//初始化
for(int i = 1; i <= n; i++){
add(i, a[i] - a[i - 1]);
}
while(m-- != 0){
str = br.readLine().split(" ");
String op = str[0];
if(op.equals("C")){
int l = Integer.parseInt(str[1]);
int r = Integer.parseInt(str[2]);
int d = Integer.parseInt(str[3]);
add(l, d); add(r + 1, -d);
} else {
int x = Integer.parseInt(str[1]);
System.out.println(sum(x));
}
}
}
}

详细用法见blog:

https://blog.csdn.net/qq_52466006/article/details/120978631

线段树

作用:

  • 求连续区间和
  • 求长度、求面积
  • 染色问题

pushup

子节点算父节点

  • pushup(u)
  • build() 将一段区间初始化成线段树
  • modify()
    • 修改单点 easy
    • 修改区间 pushdown 懒标记
  • query() 查询某段区间信息 O(4logn)

空间大小:4n(n为节点个数)

build

1
2
3
4
5
6
7
8
static void build(int u, int l, int r){
tr[u] = new Node(l, r, 0);//可变
if(l == r)return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
//pushup(u);
}

query

1
2
3
4
5
6
7
8
9
10
static int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r)return tr[u].v;//树中节点已经被完全包含在[l, r]中

int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if(l <= mid)v = query(u << 1, l, r);//与左子树有交集
if(r > mid)v = Math.max(v, query(u << 1 | 1, l, r));//与右子树有交集

return v;
}

或:

1
2
3
4
5
6
7
8
9
10
11
12
13
static int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r)return tr[u].v;//树中节点已经被完全包含在[l, r]中

int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if(r <= mid)return query(u << 1, l, r);//全在左子树
else if(l > mid)return query(u << 1 | 1, l, r);//全在右子树
else {
int left = query(u << 1, l, r);
int right = query(u << 1 | 1, l , r);
return Math.max(left, right);
}
}

pushup

1
2
3
static void pushup(int u){ //由子节点的信息,来计算父节点的信息
tr[u].v = Math.max(tr[u << 1].v, tr[u << 1 | 1].v);
}

modify(单点修改)

1
2
3
4
5
6
7
8
9
static void modify(int u, int x, int v){
if(tr[u].l == x && tr[u].r == x)tr[u].v = v;//找到了叶节点
else {
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}

例题

最大数 https://www.acwing.com/problem/content/description/1277/

维护最大值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import java.io.*;

public class Main{
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static int N = 200010;
static class Node{
int l, r;
int v; //区间[l, r]中的最大值
public Node(int l, int r){
this.l = l; this.r = r;
}
}
static Node[] tr = new Node[N * 4];

static void build(int u, int l, int r){
tr[u] = new Node(l, r);
if(l == r)return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}

static void pushup(int u){ //由子节点的信息,来计算父节点的信息
tr[u].v = Math.max(tr[u << 1].v, tr[u << 1 | 1].v);
}

static int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r)return tr[u].v;//树中节点已经被完全包含在[l, r]中

int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if(l <= mid)v = query(u << 1, l, r);//与左子树有交集
if(r > mid)v = Math.max(v, query(u << 1 | 1, l, r));//与右子树有交集

return v;
}

static void modify(int u, int x, int v){
if(tr[u].l == x && tr[u].r == x)tr[u].v = v;//找到了叶节点
else {
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}

public static void main(String[] args) throws IOException {
int n = 0, last = 0;//n表示动态序列长度 last表示上一次查询结果
String[] str = br.readLine().split(" ");
int m = Integer.parseInt(str[0]);
int p = Integer.parseInt(str[1]);
build(1, 1, m);
while(m-- != 0){
str = br.readLine().split(" ");
String op = str[0];
int x = Integer.parseInt(str[1]);
if(op.equals("Q")){
last = query(1, n - x + 1, n);
System.out.println(last);
} else {
modify(1, n + 1, (int)(((long)last + x) % p));//注意防止溢出
n++;
}
}
}
}

维护区间和:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import java.io.*;

public class Main{
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
static int N = 100010;
static class Node{
int l, r, w;
public Node(int l, int r){
this.l = l; this.r = r;
}
}
static Node[] tr = new Node[N * 4];
static int[] a = new int[N];

static void build(int u, int l, int r){
tr[u] = new Node(l, r);
if(l == r) return;
else {
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}

static void pushup(int u){
tr[u].w = tr[u << 1].w + tr[u << 1 | 1].w;
}

static int query(int u, int l, int r){
if(tr[u].l >= l && tr[u].r <= r)return tr[u].w;
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid)sum += query(u << 1, l, r);
if(r > mid)sum += query(u << 1 | 1, l, r);
return sum;
}

static void modify(int u, int x, int v){
if(tr[u].l == x && tr[u].r == x)tr[u].w += v;
else {
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}

public static void main(String[] args) throws IOException {
String[] str = br.readLine().split(" ");
int n = Integer.parseInt(str[0]);
int m = Integer.parseInt(str[1]);
str = br.readLine().split(" ");
build(1, 1, n);
for(int i = 1; i <= n; i++){
a[i] = Integer.parseInt(str[i - 1]);
modify(1, i, a[i]);
}
while(m-- != 0){
str = br.readLine().split(" ");
int k = Integer.parseInt(str[0]);
int a = Integer.parseInt(str[1]);
int b = Integer.parseInt(str[2]);
if(k == 0){
bw.write(query(1, a, b) + "\n");
} else {
modify(1, a, b);
}
}
bw.flush();
}
}

pushdown(懒标记、延迟标记)

父节点算子节点

扫描线