【专题】扩展KMP

简介

给定长度为n的文本串S和长度为m的模式串T,定义extend[i]为S[i…n-1]与T[0…m-1]的最长公共前缀长度,求extend[i]。当extend[x]==m时,则可知文本串S中包含模式串T,并且首位置为x,而这正是KMP算法处理的模式匹配问题。相较于KMP算法,扩展KMP算法能找到文本串S中所有模式串T的匹配,更一般地,可以知道文本串S中以每个字符开始的后缀与模式串T的最长公共前缀长度,时间复杂度O(n+m)。

扩展KMP

假设遍历到i时,已经求出extend[0…i-1]的值。我们记录下遍历过程中匹配成功的字符的最远位置为p,并且这次匹配的起始位置为a。换句话说,计算到x时,匹配成功的最远位置是x+extend[x]-1,p就是x=0…i-1得到的最大值,a就是对应的x。

设辅助数组next[i]表示T[i…m-1]和T[0…m-1]的最长公共前缀长度(注意这里与KMP算法中的next数组含义不同)。令len=next[i-a],讨论以下两种情况:

  1. i + len-1 < p
小于

由next数组定义可知 T[0…len-1] == T[i-a…i-a+len-1],并且由extend数组定义可知 S[a…p] == T[0…p-a],得到 S[i…i+len-1] == T[i-a…i-a+len-1] ,所以 S[i…i+len-1] == T[0…len-1]。并且 S[i+len-1…p] != T[len-1…i-a],不然违背了next数组最长公共前缀长度的定义。于是,无需任何比较就可以得出extend[i] = len。

  1. i + len-1 >= p
大于

我们可以看到,匹配最远的位置只到p,对于p之后的匹配结果是未知的,所以我们需要继续匹配S[p…n-1]和T[p-i…m-1]。匹配完之后,还需要更新a和p。

以上就是扩展KMP的主要思想,遍历即可,下面给出字符串下标从0开始的一种exKMP的实现,注意为了实现方便,代码中的p是上面分析的p+1,所以判断的时候是i+next[i-a]。而且同时用了变量j跟踪p在文本串S模式串T中的对应位置,注意遍历文本串S时,i在递增,j需要递减。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void exKMP(char* s, char* t, int n, int m)
{
int a, p;
for (int i=0, j=-1; i<n; i++, j--) //j即等于p与i的距离,其作用是判断i是否大于p(如果j<0,则i大于p)
{
if (j<0 || i+nxt[i-a] >= p)
{
if (j<0) p = i, j = 0; //如果i大于p
while (p<n && j<m && s[p]==t[j]) p++, j++;
extend[i] = j, a = i;
} else
extend[i] = nxt[i-a];
}
}

next数组

根据next数组定义,next[i]表示T[i…n]和T[0…n]的最长公共前缀的长度,其实就相当于模式串T与自身的匹配,不过这里next数组是从1开始的,next[0]匹配的肯定是整个串的长度m。有兴趣的可以试试写到一个函数中去,下面模板中给出分离的实现。再次强调C++中next是关键字(迭代器的一个函数),可以换个名字避免 CE

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
char s[mxn], t[mxn];
int nxt[mxn], extend[mxn];

void getnxt(char* t, int m)
{
int a, p; nxt[0] = m;
for (int i=1, j=-1; i<m; i++, j--)
{
if (j<0 || i+nxt[i-a] >= p)
{
if (j<0) p = i, j = 0;
while (p<m && t[p]==t[j]) p++, j++;
nxt[i] = j, a = i;
} else
nxt[i] = nxt[i-a];
}
}

void exKMP(char* s, char* t, int n, int m)
{
int a, p;
for (int i=0, j=-1; i<n; i++, j--) //j即等于p与i的距离,其作用是判断i是否大于p(如果j<0,则i大于p)
{
if (j<0 || i+nxt[i-a] >= p)
{
if (j<0) p = i, j = 0; //如果i大于p
while (p<n && j<m && s[p]==t[j]) p++, j++;
extend[i] = j, a = i;
} else
extend[i] = nxt[i-a];
}
}


int main()
{
scanf("%s%s", s, t);
getnxt(t, strlen(t));
exKMP(s, t, strlen(s), strlen(t));

for(int i=0; i<strlen(t); i++)
printf("%d ", nxt[i]);
printf("\n");

for(int i=0; i<strlen(s); i++)
printf("%d ", extend[i]);
printf("\n");

return 0;
}