KMP算法的介绍参见维基百科: https://en.wikipedia.org/wiki/Knuth–Morris–Pratt_algorithm 这篇文章的解释不错: http://www.ruanyifeng.com/blog/2013/05/Knuth–Morris–Pratt_algorithm.html 算法的关键在于next数组的计算,这一步相当于模式串自己跟自己进行匹配。 算法的时间复杂度为O(m+n),其中m和n分别是文本串和模式串的长度。 这张图有助于对算法的理解:
(图片来自 http://wiki.jikexueyuan.com/project/kmp-algorithm/define.html) 《算法导论》中的伪代码:
C++代码(与《算法导论》中的版本稍有不同;另外,next数组的计算还有可以优化的地方):
// Fills `pattern` with the KMP "next" table for `str`: pattern[0] = -1 and,
// for j > 0, pattern[j] is the length of the longest proper prefix of
// str[0..j) that is also a suffix of str[0..j).
// `pattern` must hold at least str.size() elements. Does nothing for an
// empty string.
void GeneratePattern(const std::string &str, std::vector<int> &pattern) {
    int len = static_cast<int>(str.size());
    if (len == 0) {
        return;  // pattern[0] would be out of range
    }
    pattern[0] = -1;
    int j = 1;
    int k = -1;
    while (j < len) {
        if (k == -1 || str[j - 1] == str[k]) {
            // str[0..k] is both a prefix and a suffix of str[0..j).
            k++;
            pattern[j] = k;
            j++;
        } else {
            // Fall back to the next shorter border and retry.
            k = pattern[k];
        }
    }
}

// Returns the index of the first occurrence of `needle` in `haystack`, using
// the table built by GeneratePattern, or -1 if there is none.
// An empty needle matches at index 0.
int Match(const std::string &haystack, const std::string &needle,
          const std::vector<int> &pattern) {
    int hlen = static_cast<int>(haystack.size());
    int nlen = static_cast<int>(needle.size());
    if (nlen == 0) {
        return 0;  // pattern is empty, so pattern[k] would be out of range
    }
    int j = 0;  // index into haystack; never moves backwards
    int k = 0;  // index into needle; starts at 0, not -1
    while (j < hlen) {
        if (k == -1 || haystack[j] == needle[k]) {
            j++;
            k++;
            if (k == nlen) {
                return j - k;
            }
        } else {
            // Mismatch: shift the needle using the table.
            k = pattern[k];
        }
    }
    return -1;
}

// Returns the index of the first occurrence of `needle` in `haystack`, or -1
// if it does not occur. An empty needle matches at index 0.
// Uses KMP, O(haystack.size() + needle.size()).
// The helpers are defined above so this compiles as free functions.
int strStr(const std::string &haystack, const std::string &needle) {
    int hlen = static_cast<int>(haystack.size());
    int nlen = static_cast<int>(needle.size());
    if (nlen == 0) {
        return 0;
    }
    if (hlen < nlen) {
        return -1;  // a longer needle can never match
    }
    std::vector<int> pattern(nlen);
    GeneratePattern(needle, pattern);
    return Match(haystack, needle, pattern);
}
Golang代码:
package main

import (
	"fmt"
)

func main() {
	fmt.Println(StrStr("abaa", "aa"))
}

// StrStr returns the byte index of the first occurrence of needle in
// haystack, or -1 if needle does not occur. An empty needle matches at 0.
// It uses the Knuth–Morris–Pratt algorithm and runs in
// O(len(haystack)+len(needle)).
func StrStr(haystack string, needle string) int {
	m := len(haystack)
	n := len(needle)
	if n == 0 {
		return 0
	}
	if m < n {
		return -1 // a longer needle can never match
	}
	pattern := make([]int, n)
	GeneratePattern(needle, pattern)
	return Match(haystack, needle, pattern)
}

// GeneratePattern fills pattern with the KMP "next" table for str:
// pattern[0] = -1 and, for j > 0, pattern[j] is the length of the longest
// proper prefix of str[:j] that is also a suffix of str[:j].
// pattern must have at least len(str) elements. It does nothing when str
// is empty.
func GeneratePattern(str string, pattern []int) {
	length := len(str)
	if length == 0 {
		return // pattern[0] would be out of range
	}
	pattern[0] = -1
	j, k := 1, -1
	for j < length {
		if k == -1 || str[j-1] == str[k] {
			// str[:k+1] is both a prefix and a suffix of str[:j].
			k++
			pattern[j] = k
			j++
		} else {
			// Fall back to the next shorter border and retry.
			k = pattern[k]
		}
	}
}

// Match returns the byte index of the first occurrence of needle in haystack,
// using the table built by GeneratePattern, or -1 if there is none.
// An empty needle matches at 0.
func Match(haystack string, needle string, pattern []int) int {
	hlen := len(haystack)
	nlen := len(needle)
	if nlen == 0 {
		return 0 // pattern is empty, so pattern[k] would be out of range
	}
	j := 0 // index into haystack; never moves backwards
	k := 0 // index into needle; starts at 0, not -1
	for j < hlen {
		if k == -1 || haystack[j] == needle[k] {
			j++
			k++
			if k == nlen {
				return j - k
			}
		} else {
			// Mismatch: shift the needle using the table.
			k = pattern[k]
		}
	}
	return -1
}

// Python代码:
def strStr(haystack, needle):
    """Return the index of the first occurrence of ``needle`` in ``haystack``.

    Uses the Knuth-Morris-Pratt algorithm, which runs in
    O(len(haystack) + len(needle)) time. Returns -1 when ``needle`` does
    not occur and 0 when ``needle`` is empty.

    :type haystack: str
    :type needle: str
    :rtype: int
    """
    hlen = len(haystack)
    nlen = len(needle)
    if nlen == 0:
        return 0
    if hlen < nlen:
        return -1  # a longer needle can never match
    pattern = [0] * nlen
    GeneratePattern(needle, pattern)
    return Match(haystack, needle, pattern)


def GeneratePattern(string, pattern):
    """Fill ``pattern`` in place with the KMP "next" table for ``string``.

    ``pattern[0]`` is -1 and, for j > 0, ``pattern[j]`` is the length of
    the longest proper prefix of ``string[:j]`` that is also a suffix of
    ``string[:j]``. ``pattern`` must have at least ``len(string)`` slots.
    Does nothing when ``string`` is empty.
    """
    length = len(string)
    if length == 0:
        return  # pattern[0] would raise IndexError
    pattern[0] = -1
    j = 1
    k = -1
    while j < length:
        if k == -1 or string[j - 1] == string[k]:
            # string[:k + 1] is both a prefix and a suffix of string[:j].
            k += 1
            pattern[j] = k
            j += 1
        else:
            # Fall back to the next shorter border and retry.
            k = pattern[k]


def Match(haystack, needle, pattern):
    """Return the index of the first occurrence of ``needle`` in ``haystack``.

    ``pattern`` is the table produced by GeneratePattern for ``needle``.
    Returns -1 if there is no match and 0 when ``needle`` is empty.
    """
    hlen = len(haystack)
    nlen = len(needle)
    if nlen == 0:
        return 0  # pattern is empty, so pattern[k] would raise IndexError
    j = 0  # index into haystack; never moves backwards
    k = 0  # index into needle; starts at 0, not -1
    while j < hlen:
        if k == -1 or haystack[j] == needle[k]:
            j += 1
            k += 1
            if k == nlen:
                return j - k
        else:
            # Mismatch: shift the needle using the table.
            k = pattern[k]
    return -1


if __name__ == '__main__':
    # print(...) with one argument works on both Python 2 and Python 3.
    print(strStr("abaa", "aa"))