Algorithm -- strstr 的 Rabin Karp 解法
C’s strStr() or Java’s indexOf() 是两个常用函数,也是一道经典的算法问题。
比如LeetCode第28题,几乎所有旁友都做过这题。
Input: haystack = "aaaaaaaaaaaaaaaaaabaaaaaa", needle = "aab"
Output: 16
相信很多CS或非CS专业的同学听说过KMP,甚至仔细学习过。但是在面试中是不会要求实现KMP这样的算法的。
假设函数签名是这样的: public int strStr(String s, String t)
,其中s
表示大String
,t
表示小String
;要求返回的是t在s中第一次出现的起始index,如果找不到则返回-1。
为了简单起见,以下代码都假设给定的String不为null,并假设 S.length() >= T.length()。
举个例子:
S=
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
T | h | e | q | u | i | c | k | b | r | o | w | n | f | o | x | j | u | m | p |
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | |||||
s | o | v | e | r | t | h | e |
l |
a |
z |
y |
d | o | g |
T=
0 | 1 | 2 | 3 | 4 |
e | l | a | z | y |
方法一: Brute Force
从S中依次选取每一段和T相同长度的字符串,命名为X,如果X和T相等,则找到并返回X的起始index。code中可以用S.substring(i, i + T.length())
代替X,如果找到直接返回i.
public int strStr(String S, String T) {
for (int i = 0; i < S.length() - T.length() + 1; ++i) {
if (S.substring(i, i + T.length()).equals(T) {
return i;
}
}
return -1;
}
这样的写法,时间复杂度是O(S * T),这里假设substring()和equals()的时间复杂度都是O(T)。参考 Time Complexity of Java’s Substring
由于substring()会创造额外的空间,所以空间开销很大O(S * T)。 如果对空间进行一点优化,不创造额外的String,直接每个字符比较,可以写成:
public int strStr(String S, String T) {
int i = 0, j = 0;
while (i < S.length() - T.length() + 1) {
while (j < T.length() && S.charAt(i+j) == T.charAt(j)) {
j++;
}
if (j == T.length()) {
return i;
}
j = 0;
i++;
}
return -1;
}
这样的写法,时间复杂度最坏的情况下依然是O(S * T),但是不消耗额外的空间。
方法二: Rabin-Karp
从方法一其实是从S中选取一个个和T相等长度的窗口X,再和T进行比较。 方法一降低了选取选取窗口的时间复杂度,但并未降低窗口X和T比较的时间复杂度。 也就是说,如果能降低窗口X和T比较的时间复杂度的话,总体时间复杂度就能进一步降低。
Rabin-Karp就是这样一个算法,其中所使用的技巧就是不直接比较String,而是比较String的特征值,比如我们很熟悉的hash value. 算法大致是这样的:
public int strStr(String S, String T) {
for (int i = 0; i < S.length() - T.length() + 1; i++) {
if (S.substring(i, i + T.length()).hashCode() == T.hashCode()) {
return i;
}
}
return -1;
}
这样的代码在LeetCode上是可以通过的,但是有两个很严重的错误:
- substring() 依然要消耗O(T)的时间以及O(T)的空间; ← 主要矛盾
- String.hashCode() 的计算要消耗O(T)的时间; ← 主要矛盾
- String A.hashCode() == B.hashCode() 并不能保证 A.equals(B),因为有hash collision的可能存在(虽然可能性很小). ← 次要矛盾
Rabin-Karp中用来解决主要矛盾的方法叫做: Rolling Hash;
换做大家更熟悉的叫法可以是 Sliding Window Hash.
WikiPedia上给了一个简单的公式:s[i+1..i+m] = s[i..i+m-1] - s[i] + s[i+m]
这个方法是靠 递推 来解决主要矛盾的。这个 递推 所用的时间越少,就越能降低总体时间复杂度。
现在来看下这个递推是如何实现的:
Step_1: 一个普通的hash计算方法
hash(S[0..k]) = (S[0] * prime^k + S[1] * prime^(k-1) + … + S[k] * prime^0) % largePrime
hash(S[1..k+1]) = (S[1] * prime^k + S[2] * prime^(k-1) + … + S[k+1] * prime^0) % largePrime
Step_2: 从S[0..k] 到 S[1..k+1]
S[1] * prime^k + S[2] * prime^(k-1) + … + S[k+1] * prime^0
= ( S[0] * prime^k + S[1] * prime^(k-1) + … + S[k] * prime^0 ) * prime
- S[0] * prime^(k+1)
+ S[k+1] * prime^0
Step_3: 结合Step_1 和 Step2, 处理计算结果overflow的情况,引入 mod 运算
如果 A == B,那么显然 A % Q == B % Q;
现在有 X = Y * p - c * p^k + d,同样可以有 X % Q = (Y * p - c * p^k + d ) % Q,
即 X % Q = (Y * p % Q - c * p^k % Q + d % Q) % Q,
因为 mod 运算满足分配率
根据以上三步,可以得到一个有用的算法,即在sliding window的每个中间过程中都可以进行mod运算而不影响结果的正确性,同时又避免了overflow的问题。
为了方便阅读,我做了很多简化和hardcode
public int strStr(String S, String T) {
return strStr(S.toCharArray(), T.toCharArray());
}
private int strstr(char[] large, char[] small) {
int seed = 1;
int rollingHash = 0;
int smallHash = 0;
for (int i = 0; i < small.length; ++i) {
if (i > 0) {
seed = seed * 31 % 101;
}
smallHash = (smallHash * 31 % 101 + small[i]) % 101;
rollingHash = (rollingHash * 31 % 101 + large[i]) % 101;
}
for (int i = 0; i < large.length - small.length + 1; ++i) {
if (i > 0) {
rollingHash = ((rollingHash - seed * large[i-1]) % 101 + 101) % 101;
rollingHash = (rollingHash * 31 % 101 + large[i + small.length - 1] ) % 101 ;
}
if (rollingHash == smallHash && equal(large, i, small)) {
return i;
}
}
return -1;
}
private boolean equal(char[] large, int i, char small) { …略... } // O(T)时间一个个字符比较
其中,smallHash就是hash(T),rollingHash是S中每一个滑动窗口X的hash(X); 在不考虑 次要矛盾(hash collision) 的情况下,只要 hash(T) == hash(X),就能推出T.equals(X).
Seed的作用: seed % largePrime
必须为1。
还记得__Step_2__中那个被减去的S[0] * prime^(k+1)
吗?由于prime^(k+1)
比较大,计算需要花费不少时间,并且被频繁使用,所以这里提前把prime^(k+1)
的值给计算了出来,并且% largePrime
也不影响正确性。因为seed == prime^(k+1) % largePrime
。
举个例子:
为了简单直观,我用S = “thelazydog”, T = “lazy”;其中prime用了31,largePrime用了101。
这是以长度5为单位的每个滑动窗口的hash值,读者可以自行进行计算。
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
t | h | e | l | a | z | y | d | o | g |
58 | 91 | 31 |
51 | 17 |
0 | 1 | 2 | 3 | 4 |
e | l | a | z | y |
31 |
Rabin-Karp算法的时间复杂度约等于O(S + T),因为hash collision的几率可以根据需要设计得很小以至忽略不计。空间复杂度同样是常数级别的。(为了代码简洁我把String转换成了Char array,实际算法中不需要这样的处理)