def insert(str_input):
    """Add one pattern string to the global trie ``tree``.

    Starting at the root (node 0), each character is mapped to a column with
    ``ord(ch) - ord('a')``. When the current node has no child in that column
    (entry is 0), a new node id is taken from the global counter ``cnt`` and
    stored there. After the last character, ``end`` is incremented at the
    node where the pattern stops, so ``end[node]`` counts how many inserted
    patterns end at that node.

    NOTE(review): characters must map to valid columns of ``tree``; nothing
    here validates that (e.g. uppercase letters give a negative column).
    """
    global cnt
    base = ord('a')
    node = 0
    for letter in str_input:
        col = ord(letter) - base
        if tree[node][col] == 0:
            # No edge yet: allocate the next node id for this character.
            cnt += 1
            tree[node][col] = cnt
        # For prefix counting, one could also do: sum_[tree[node][col]] += 1
        # (disabled in the original; ``sum_`` is not defined in this chunk).
        node = tree[node][col]
    # Mark the terminal node: one more pattern ends here.
    end[node] += 1
def get_fail(alphabet_size=27):
    """Build the Aho-Corasick failure links over the global trie ``tree``.

    Performs a breadth-first traversal starting from the root's children.
    For every dequeued node ``v`` and every column ``i``:

    * if ``v`` has a real child ``c`` on ``i``, then ``fail[c]`` becomes
      ``tree[fail[v]][i]`` (the longest proper suffix state), and ``c`` is
      queued;
    * otherwise the missing edge is filled with ``tree[fail[v]][i]``, turning
      ``tree`` into a complete goto table so ``query`` can take exactly one
      ``tree`` step per text character.

    ``tree`` is mutated in place, so call this once, after all patterns have
    been inserted. NOTE(review): running it again after more inserts would
    treat the filled-in goto edges as real children and corrupt ``fail``.

    Args:
        alphabet_size: number of columns of ``tree`` to scan. Defaults to 27,
            the value previously hard-coded; every row of ``tree`` must have
            at least this many entries.
    """
    # A plain deque is the idiomatic FIFO for BFS; queue.Queue adds locking
    # meant for cross-thread producer/consumer use.
    from collections import deque

    fail[0] = 0
    pending = deque()
    root_row = tree[0]
    for i in range(alphabet_size):
        child = root_row[i]
        if child:
            # Depth-1 nodes always fall back to the root; set it explicitly
            # rather than relying on ``fail`` having been zero-initialised.
            fail[child] = 0
            pending.append(child)

    while pending:
        v = pending.popleft()
        row = tree[v]
        # fail[v] is strictly shallower than v, so BFS has already completed
        # its row; reuse it for every column of v.
        fallback = tree[fail[v]]
        for i in range(alphabet_size):
            child = row[i]
            if child:
                fail[child] = fallback[i]
                pending.append(child)
            else:
                row[i] = fallback[i]
def query(str_input):
    """Count how many distinct stored patterns occur in ``str_input``.

    Scans the text with the completed goto table ``tree`` (``get_fail`` must
    have been run first), one ``tree`` step per character, mapping characters
    to columns with ``ord(ch) - ord('a')``. At each position it walks the
    failure chain from the current node, adding ``end[u]`` and then setting
    ``end[u] = -1``.

    The ``-1`` marker means each trie node is counted at most once. The walk
    stops at the first already-marked node, because every earlier walk
    marked its entire chain down to the root. Each node is therefore
    processed once, and total work is O(len(text) + number of nodes).

    Side effect: ``end`` is consumed. Patterns counted by a previous call are
    not counted again, so repeated calls return only newly seen patterns
    (a second call on the same text returns 0). Reset ``end`` to query
    afresh.

    Returns:
        The number of inserted patterns (with multiplicity per ``end`` entry)
        found in the text that were not already counted.
    """
    base = ord('a')
    node = 0
    ans = 0
    # Iterate the string directly; wrapping it in list() only made a copy.
    for ch in str_input:
        node = tree[node][ord(ch) - base]
        u = node
        while u and end[u] != -1:
            ans += end[u]
            end[u] = -1
            u = fail[u]
    return ans