前缀树又叫字典树,通常用来高效地查询字符串,比如查询库中是否有以某个子字符串为前缀的字符串,某个字符串出现的次数等。前缀树是 N 叉树的一种特殊形式,每一个节点会有多个子节点,通往不同子节点的 路径上 有着不同的字符,子节点中包含两种信息:pass(经过此字符的次数),end(以此字符结尾的次数)。说多了没用,直接上图:

0-‘a’,1-’b’,2-‘c’,每个节点对应的下标即代表相应字母,除 root 外的其他节点中包含着该字母的信息。
细心的同学应该发现,root 的 end 域只可能为 0,因为这代表着插入了一个空字符串,这是不被允许的。所以我们似乎可以利用这个 end 域做些坏事。容易知道,root 的 pass 域代表着此前缀树中一共有多少个字符串(相同字符串会被重复计数),那么,我如果想知道一共有多少种字符串呢?此时,就要利用 root 的 end 域来记录了。详细注释已在代码给出,不在赘述。代码如下(仅支持小写字母):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include<iostream>
#include<string>
using std::string;
using std::cout;
using std::endl;
struct node
{
int pass;
int end;
node* nexts[26];
node():pass(0), end(0)
{
for(int i = 0; i < 26; i++)
nexts[i] = NULL;
}
};

class trieTree
{
private:
node* root;
public:
trieTree(){root = new node;}
void insert(string str);
int search(string str);
int prefixSearch(string str);
int delstr(string str, int times);
};

void trieTree::insert(string str)
{
if(str.size() == 0)
return;
int n;
node* tmp = root;
for(int i = 0; i < str.size(); i++)
{
n = str[i] - 'a';
if(tmp->nexts[n] == NULL)
tmp->nexts[n] = new node;
tmp = tmp->nexts[n];
tmp->pass++;
}
if(tmp->end == 0)
root->end++;
tmp->end++;
root->pass++;
};

int trieTree::search(string str)
{
if(str.size() == 0)//!!!
return 0;
node* tmp = root;
int i = 0;
int n;
for(i = 0; i < str.size(); i++)
{
n = str[i] - 'a';
if(tmp->nexts[n] == NULL)
return 0;
tmp = tmp->nexts[n];
}
return tmp->end;
}


int trieTree::prefixSearch(string str)
{
if(str.size() == 0)//!!!
return 0;
node* tmp = root;
int i = 0;
int n;
for(i = 0; i < str.size(); i++)
{
n = str[i] - 'a';
if(tmp->nexts[n] == NULL)
return 0;
tmp = tmp->nexts[n];
}
return tmp->pass;
}

int trieTree::delstr(string str, int times)//times为删除次数,如果为0则为删除所有该字符
{
int counts = this->search(str);
if(counts == 0)
return 0;
int t = times == 0 ? counts : (times < counts ? times : counts);
node* tmp = root;
int n;
for(int i = 0; i < str.size(); i++)
{
n = str[i] - 'a';
tmp = tmp->nexts[n];
tmp->pass -= t;
}
tmp->end -= t;
if(t != counts)
return t;
for(int i = 0; i < str.size(); i++)
{
int k = 0;
tmp = root;
for(; k < str.size()-1-i; k++)
{
n = str[k] - 'a';
tmp = tmp->nexts[n];
}
n = str[k] - 'a';
if(tmp->nexts[n]->pass == 0)
{
delete tmp->nexts[n];
tmp->nexts[n] = NULL;
}
}
return counts;
}

以上是链式实现,当字符种类很多时,一般就采用哈希实现,此方式后面有时间再给出。