AC自动机

AC自动机。

建树

AC自动机先要建一个 Trie树 。

  • ss 是有待匹配的字符串
  • trietrie 是 Trie树
  • ch[tmp]ch[tmp]tmptmp字符所在儿子的编号(没有为00
inline void build(){
    ll len=strlen(s),id=0;
    for(register int i=0;i<len;++i){
        ll tmp=s[i]-'a';
        if(!trie[id].ch[tmp])//如果没有这个儿子,就建立这个儿子
            trie[id].ch[tmp]=++cnt;
        id=trie[id].ch[tmp];//跳转到当前节点的这个儿子位置
    }
    ++trie[id].num;
    //此时 id 已经是这个字符串的结束位置了(即 root 到 id 中所有路径上的字符组成的字符串就是此时的s
    //因此 num 记录的是已这个节点结尾的字符串有多少个
    //也可以说记录相同的字符串有多少个
}

建立 failfail 数组

简单点来说,如果一个节点 uu ,根节点到 uu 的字符串为 abcdef ,而 rootroot 到另外一个节点 vv 的字符串为 bcdef 并且在所有的如此后缀中深度最深,那么这个点 vv 就是 uufailfail 要指向的地方。

建立 failfail 时需要更改 Trie树 的结构,引用一段 OI wiki 的话:

考虑字典树中当前的结点 uuuu 的父结点是 pp , 通过字符 c 的边指向 uu ,即 trie[p,c]=utrie[p,c]=u 。假设深度小于 uu 的所有结点的 fail 指针都已求得。

  1. 如果 trie[fail[p],c]trie[fail[p],c] 存在:则让 u 的 fail 指针指向 trie[fail[p],c]trie[fail[p],c] 。相当于在 ppfail[p]fail[p] 后面加一个字符 c,分别对应 uufail[u]fail[u]
  2. 如果 trie[fail[p],c]trie[fail[p],c] 不存在:那么我们继续找到 trie[fail[fail[p]],c]trie[fail[fail[p]],c] 。重复 1 的判断过程,一直跳 fail 指针直到根结点。
  3. 如果真的没有,就让 fail 指针指向根结点。

如此即完成了 fail[u]fail[u] 的构建。

至于为什么不用 whilewhile 循环而直接修改树的结构,其实很像并查集的路径压缩(和 kmp),稍微画一下图就好了...

inline void Fail(){
    queue<ll>q;
    for(register int i=0;i<26;++i)
        if(trie[0].ch[i]){
            trie[trie[0].ch[i]].fail=0;
            q.push(trie[0].ch[i]);
        }
    while(!q.empty()){
        ll u=q.front();q.pop();
        for(register int i=0;i<26;++i){
            ll v=trie[u].ch[i];
            if(v){
                trie[v].fail=trie[trie[u].fail].ch[i];
                q.push(v);
            }
            else
                trie[u].ch[i]=trie[trie[u].fail].ch[i];
        }
    }
}
//我觉得这是最好理解的一段了...

计算

题目要求是计算有多少个字符串是给定字符串 tt 的子串。

计算的时候从串头开始往下走,对于每个节点 vv 加上它的贡献值(以 vv 为终点的子串),走过之后将其变为 1-1 ,并直接跳 fail,如果一个点已经被走过了(计算过了),跳出循环。最后返回答案。简单来说,就是不断跳 fail ,不断往下搜索啊...很显然一个字符串只有可能有一个结束点啊!!!看 OI-wiki 上的图感性理解以下 。

inline ll AC(){
    ll len=strlen(t);
    ll id=0,ans=0;
    for(register int i=0;i<len;++i){
        ll tmp=t[i]-'a';
        id=trie[id].ch[tmp];
        ll v=id;
        while(v&&trie[v].num!=-1){
            ans+=trie[v].num;
            trie[v].num=-1;
            v=trie[v].fail;
        }
        //这里主要是找到此时的 root->v 的字符串的所有后缀并加入答案
    }
    return ans;
}

全代码

#include <bits/stdc++.h>
#define ll long long
#define maxn 1000101
using namespace std;
ll n,cnt;
char s[maxn],t[maxn];
struct Node{
    ll fail,num,ch[30];
}trie[maxn];

inline ll read(){
    ll x=0,f=0;char c=getchar();
    while(!isdigit(c))  f|=c=='-',c=getchar();
    while(isdigit(c))   x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return f?-x:x;
}

inline void write(ll x){
    if(x<0)     putchar('-'),x=-x;
    if(x>9)     write(x/10);
    putchar(x%10^48);
}

inline void build(){
    ll len=strlen(s),id=0;
    for(register int i=0;i<len;++i){
        ll tmp=s[i]-'a';
        if(!trie[id].ch[tmp])
            trie[id].ch[tmp]=++cnt;
        id=trie[id].ch[tmp];
    }
    ++trie[id].num;
}

inline void Fail(){
    queue<ll>q;
    for(register int i=0;i<26;++i)
        if(trie[0].ch[i]){
            trie[trie[0].ch[i]].fail=0;
            q.push(trie[0].ch[i]);
        }
    while(!q.empty()){
        ll u=q.front();q.pop();
        for(register int i=0;i<26;++i){
            ll v=trie[u].ch[i];
            if(v){
                trie[v].fail=trie[trie[u].fail].ch[i];
                q.push(v);
            }
            else
                trie[u].ch[i]=trie[trie[u].fail].ch[i];
        }
    }
}

inline ll AC(){
    ll len=strlen(t);
    ll id=0,ans=0;
    for(register int i=0;i<len;++i){
        ll tmp=t[i]-'a';
        id=trie[id].ch[tmp];
        ll v=id;
        while(v&&trie[v].num!=-1){
            ans+=trie[v].num;
            trie[v].num=-1;
            v=trie[v].fail;
        }
    }
    return ans;
}

int main(){
    n=read();
    for(register int i=1;i<=n;++i){
        cin>>s;
        build();
    }
    trie[0].fail=0;
    Fail();
    cin>>t;
    write(AC());
    return 0;
}