diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go
new file mode 100644
index 0000000..955021f
--- /dev/null
+++ b/pkg/filter/filter.go
@@ -0,0 +1,7 @@
+package filter
+
+import "github.com/Qianlitp/crawlergo/pkg/model"
+
+type FilterHandler interface {
+    DoFilter(req *model.Request) bool
+}
diff --git a/pkg/filter/simple_filter.go b/pkg/filter/simple_filter.go
index f0e1682..5fad1cc 100755
--- a/pkg/filter/simple_filter.go
+++ b/pkg/filter/simple_filter.go
@@ -9,21 +9,23 @@ import (
 )
 
 type SimpleFilter struct {
-    UniqueSet mapset.Set
-    HostLimit string
+    UniqueSet       mapset.Set
+    HostLimit       string
+    staticSuffixSet mapset.Set
 }
 
-var (
-    staticSuffixSet = config.StaticSuffixSet.Clone()
-)
+func NewSimpleFilter(host string) *SimpleFilter {
+    staticSuffixSet := config.StaticSuffixSet.Clone()
 
-func init() {
     for _, suffix := range []string{"js", "css", "json"} {
         staticSuffixSet.Add(suffix)
     }
+    s := &SimpleFilter{UniqueSet: mapset.NewSet(), staticSuffixSet: staticSuffixSet, HostLimit: host}
+    return s
 }
 
-/**
+/*
+*
 Return true if the request should be filtered out
 */
 func (s *SimpleFilter) DoFilter(req *model.Request) bool {
@@ -45,7 +47,8 @@ func (s *SimpleFilter) DoFilter(req *model.Request) bool {
     return false
 }
 
-/**
+/*
+*
 Request deduplication
 */
 func (s *SimpleFilter) UniqueFilter(req *model.Request) bool {
@@ -60,7 +63,8 @@ func (s *SimpleFilter) UniqueFilter(req *model.Request) bool {
     }
 }
 
-/**
+/*
+*
 Static resource filtering
 */
 func (s *SimpleFilter) StaticFilter(req *model.Request) bool {
@@ -72,13 +76,14 @@ func (s *SimpleFilter) StaticFilter(req *model.Request) bool {
     if req.URL.FileExt() == "" {
         return false
     }
-    if staticSuffixSet.Contains(req.URL.FileExt()) {
+    if s.staticSuffixSet.Contains(req.URL.FileExt()) {
         return true
     }
     return false
 }
 
-/**
+/*
+*
 Keep only links under the specified domain
 */
 func (s *SimpleFilter) DomainFilter(req *model.Request) bool {
diff --git a/pkg/filter/smart_filter.go b/pkg/filter/smart_filter.go
index b515ce5..a8c21cf 100755
--- a/pkg/filter/smart_filter.go
+++ b/pkg/filter/smart_filter.go
@@ -16,8 +16,8 @@ import (
 )
 
 type SmartFilter struct {
-    StrictMode   bool
-    SimpleFilter SimpleFilter
+    StrictMode bool
+    *SimpleFilter
     filterLocationSet          mapset.Set // locations of non-logical parameters, marked and filtered globally
     filterParamKeyRepeatCount  sync.Map
     filterParamKeySingleValues sync.Map   // repeat-count statistics for all parameter names
@@ -74,7 +74,8 @@ var onlyAlphaNumRegex = regexp.MustCompile(`^[0-9a-zA-Z]+$`)
 var markedStringRegex = regexp.MustCompile(`^{{.+}}$`)
 var htmlReplaceRegex = regexp.MustCompile(`\.shtml|\.html|\.htm`)
 
-func (s *SmartFilter) Init() {
+func NewSmartFilter(base *SimpleFilter, strictMode bool) *SmartFilter {
+    s := &SmartFilter{}
     s.filterLocationSet = mapset.NewSet()
     s.filterParamKeyRepeatCount = sync.Map{}
     s.filterParamKeySingleValues = sync.Map{}
@@ -83,9 +84,13 @@ func (s *SmartFilter) Init() {
     s.filterPathParamEmptyValues = sync.Map{}
     s.filterParentPathValues = sync.Map{}
     s.uniqueMarkedIds = mapset.NewSet()
+    s.SimpleFilter = base
+    s.StrictMode = strictMode
+    return s
 }
 
-/**
+/*
+*
 Smart deduplication
 Optional strict mode
 
@@ -149,7 +154,8 @@ func (s *SmartFilter) DoFilter(req *model.Request) bool {
     return false
 }
 
-/**
+/*
+*
 The query map is URL-decoded automatically, so RawQuery is marked beforehand
 */
 func (s *SmartFilter) preQueryMark(rawQuery string) string {
@@ -163,7 +169,8 @@ func (s *SmartFilter) preQueryMark(rawQuery string) string {
     return rawQuery
 }
 
-/**
+/*
+*
 Mark the parameters and path of a GET request
 */
 func (s *SmartFilter) getMark(req *model.Request) {
@@ -199,7 +206,8 @@ func (s *SmartFilter) getMark(req *model.Request) {
     req.Filter.UniqueId = getMarkedUniqueID(req)
 }
 
-/**
+/*
+*
 Mark the parameters and path of a POST request
 */
 func (s *SmartFilter) postMark(req *model.Request) {
@@ -227,7 +235,8 @@ func (s *SmartFilter) postMark(req *model.Request) {
     req.Filter.UniqueId = getMarkedUniqueID(req)
 }
 
-/**
+/*
+*
 Mark parameter names
 */
 func markParamName(paramMap map[string]interface{}) map[string]interface{} {
@@ -248,7 +257,8 @@ func markParamName(paramMap map[string]interface{}) map[string]interface{} {
     return markedParamMap
 }
 
-/**
+/*
+*
 Mark parameter values
 */
 func (s *SmartFilter) markParamValue(paramMap map[string]interface{}, req model.Request) map[string]interface{} {
@@ -336,7 +346,8 @@ func (s *SmartFilter) markParamValue(paramMap map[string]interface{}, req model.
     return markedParamMap
 }
 
-/**
+/*
+*
 Mark the path
 */
 func MarkPath(path string) string {
@@ -376,7 +387,8 @@ func MarkPath(path string) string {
     return newPath
 }
 
-/**
+/*
+*
 Global filtering of numeric parameters
 */
 func (s *SmartFilter) globalFilterLocationMark(req *model.Request) {
@@ -398,7 +410,8 @@ func (s *SmartFilter) globalFilterLocationMark(req *model.Request) {
     }
 }
 
-/**
+/*
+*
 Collect global statistics of repeated parameter names, parameter values and paths
 Then mark again the parts that exceed the thresholds
 */
@@ -483,7 +496,8 @@ func (s *SmartFilter) repeatCountStatistic(req *model.Request) {
     }
 }
 
-/**
+/*
+*
 After the repeat statistics, mark again the parts that exceed the thresholds
 */
 func (s *SmartFilter) overCountMark(req *model.Request) {
@@ -571,7 +585,8 @@ func (s *SmartFilter) calcFragmentID(fragment string) string {
     return fakeReq.Filter.UniqueId
 }
 
-/**
+/*
+*
 Compute the unique request ID after marking
 */
 func getMarkedUniqueID(req *model.Request) string {
@@ -593,7 +608,8 @@ func getMarkedUniqueID(req *model.Request) string {
     return tools.StrMd5(uniqueStr)
 }
 
-/**
+/*
+*
 Compute the unique ID of the marked request parameter keys
 */
 func getKeysID(dataMap map[string]interface{}) string {
@@ -609,7 +625,8 @@ func getKeysID(dataMap map[string]interface{}) string {
     return tools.StrMd5(idStr)
 }
 
-/**
+/*
+*
 Compute the unique ID of the marked request parameters
 */
 func getParamMapID(dataMap map[string]interface{}) string {
@@ -630,14 +647,16 @@ func getParamMapID(dataMap map[string]interface{}) string {
     return tools.StrMd5(idStr)
 }
 
-/**
+/*
+*
 Compute the unique ID of the marked PATH
 */
 func getPathID(path string) string {
     return tools.StrMd5(path)
 }
 
-/**
+/*
+*
 Check whether the string contains any of the following special symbols
 */
 func hasSpecialSymbol(str string) bool {
diff --git a/pkg/filter/smart_filter_test.go b/pkg/filter/smart_filter_test.go
index 2ce1fb8..2d18ac4 100644
--- a/pkg/filter/smart_filter_test.go
+++ b/pkg/filter/smart_filter_test.go
@@ -31,11 +31,10 @@ var (
     // completeUrls = []string{
     //     "https://test.local.com:1234/adfatd/123456/sx14xi?user=crawlergo&pwd=fa1424&end=1#/user/info",
     // }
-    smart = SmartFilter{}
+    smart = NewSmartFilter(NewSimpleFilter(""), true)
 )
 
 func TestDoFilter_countFragment(t *testing.T) {
-    smart.Init()
     reqs := []model.Request{}
     for _, fu := range fragmentUrls {
         url, err := model.GetUrl(fu)
diff --git a/pkg/task_main.go b/pkg/task_main.go
index 52d9249..15c54b5 100755
--- a/pkg/task_main.go
+++ b/pkg/task_main.go
@@ -6,6 +6,7 @@ import (
 
     "github.com/Qianlitp/crawlergo/pkg/config"
     engine2 "github.com/Qianlitp/crawlergo/pkg/engine"
+    "github.com/Qianlitp/crawlergo/pkg/filter"
     filter2 "github.com/Qianlitp/crawlergo/pkg/filter"
     "github.com/Qianlitp/crawlergo/pkg/logger"
     "github.com/Qianlitp/crawlergo/pkg/model"
@@ -14,16 +15,16 @@ import (
 )
 
 type CrawlerTask struct {
-    Browser       *engine2.Browser    //
-    RootDomain    string              // root domain of the current crawl, used for subdomain collection
-    Targets       []*model.Request    // input targets
-    Result        *Result             // final result
-    Config        *TaskConfig         // task configuration
-    smartFilter   filter2.SmartFilter // filter instance
-    Pool          *ants.Pool          // goroutine pool
-    taskWG        sync.WaitGroup      // waits for all tasks in the pool to finish
-    crawledCount  int                 // number of crawled requests
-    taskCountLock sync.Mutex          // lock for the crawled-task counter
+    Browser       *engine2.Browser     //
+    RootDomain    string               // root domain of the current crawl, used for subdomain collection
+    Targets       []*model.Request     // input targets
+    Result        *Result              // final result
+    Config        *TaskConfig          // task configuration
+    filter        filter.FilterHandler // filter instance
+    Pool          *ants.Pool           // goroutine pool
+    taskWG        sync.WaitGroup       // waits for all tasks in the pool to finish
+    crawledCount  int                  // number of crawled requests
+    taskCountLock sync.Mutex           // lock for the crawled-task counter
 }
 
 type Result struct {
@@ -40,18 +41,26 @@ type tabTask struct {
     req         *model.Request
 }
 
-/**
+/*
+*
 Create a new crawler task
 */
 func NewCrawlerTask(targets []*model.Request, taskConf TaskConfig) (*CrawlerTask, error) {
     crawlerTask := CrawlerTask{
         Result: &Result{},
         Config: &taskConf,
-        smartFilter: filter2.SmartFilter{
-            SimpleFilter: filter2.SimpleFilter{
-                HostLimit: targets[0].URL.Host,
-            },
-        },
+    }
+
+    baseFilter := filter.NewSimpleFilter(targets[0].URL.Host)
+
+    if taskConf.FilterMode == config.SmartFilterMode {
+        crawlerTask.filter = filter.NewSmartFilter(baseFilter, false)
+
+    } else if taskConf.FilterMode == config.StrictFilterMode {
+        crawlerTask.filter = filter.NewSmartFilter(baseFilter, true)
+
+    } else {
+        crawlerTask.filter = baseFilter
     }
 
     if len(targets) == 1 {
@@ -98,8 +107,6 @@ func NewCrawlerTask(targets []*model.Request, taskConf TaskConfig) (*CrawlerTask
     crawlerTask.Browser = engine2.InitBrowser(taskConf.ChromiumPath, taskConf.ExtraHeaders, taskConf.Proxy, taskConf.NoHeadless)
     crawlerTask.RootDomain = targets[0].URL.RootDomain()
 
-    crawlerTask.smartFilter.Init()
-
     // create the goroutine pool
     p, _ := ants.NewPool(taskConf.MaxTabsCount)
     crawlerTask.Pool = p
@@ -107,7 +114,8 @@ func NewCrawlerTask(targets []*model.Request, taskConf TaskConfig) (*CrawlerTask
     return &crawlerTask, nil
 }
 
-/**
+/*
+*
 Generate tabTask goroutine tasks from the request list
 */
 func (t *CrawlerTask) generateTabTask(req *model.Request) *tabTask {
@@ -119,7 +127,8 @@ func (t *CrawlerTask) generateTabTask(req *model.Request) *tabTask {
     return &task
 }
 
-/**
+/*
+*
 Start the current task
 */
 func (t *CrawlerTask) Run() {
@@ -148,7 +157,7 @@ func (t *CrawlerTask) Run() {
 
     var initTasks []*model.Request
     for _, req := range t.Targets {
-        if t.smartFilter.DoFilter(req) {
+        if t.filter.DoFilter(req) {
             logger.Logger.Debugf("filter req: " + req.URL.RequestURI())
             continue
         }
@@ -183,7 +192,8 @@ func (t *CrawlerTask) Run() {
     t.Result.SubDomainList = SubDomainCollect(t.Result.AllReqList, t.RootDomain)
 }
 
-/**
+/*
+*
 Add a task to the goroutine pool
 Filter in real time before adding
 */
@@ -208,7 +218,8 @@ func (t *CrawlerTask) addTask2Pool(req *model.Request) {
     }()
 }
 
-/**
+/*
+*
 A single running browser-tab task, implements the workpool interface
 */
 func (t *tabTask) Task() {
@@ -232,23 +243,12 @@ func (t *tabTask) Task() {
     t.crawlerTask.Result.resultLock.Unlock()
 
     for _, req := range tab.ResultList {
-        if t.crawlerTask.Config.FilterMode == config.SimpleFilterMode {
-            if !t.crawlerTask.smartFilter.SimpleFilter.DoFilter(req) {
-                t.crawlerTask.Result.resultLock.Lock()
-                t.crawlerTask.Result.ReqList = append(t.crawlerTask.Result.ReqList, req)
-                t.crawlerTask.Result.resultLock.Unlock()
-                if !engine2.IsIgnoredByKeywordMatch(*req, t.crawlerTask.Config.IgnoreKeywords) {
-                    t.crawlerTask.addTask2Pool(req)
-                }
-            }
-        } else {
-            if !t.crawlerTask.smartFilter.DoFilter(req) {
-                t.crawlerTask.Result.resultLock.Lock()
-                t.crawlerTask.Result.ReqList = append(t.crawlerTask.Result.ReqList, req)
-                t.crawlerTask.Result.resultLock.Unlock()
-                if !engine2.IsIgnoredByKeywordMatch(*req, t.crawlerTask.Config.IgnoreKeywords) {
-                    t.crawlerTask.addTask2Pool(req)
-                }
+        if !t.crawlerTask.filter.DoFilter(req) {
+            t.crawlerTask.Result.resultLock.Lock()
+            t.crawlerTask.Result.ReqList = append(t.crawlerTask.Result.ReqList, req)
+            t.crawlerTask.Result.resultLock.Unlock()
+            if !engine2.IsIgnoredByKeywordMatch(*req, t.crawlerTask.Config.IgnoreKeywords) {
+                t.crawlerTask.addTask2Pool(req)
             }
         }
     }
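
Net effect of the diff: CrawlerTask now holds a single filter.FilterHandler, and the choice between simple, smart and strict filtering is made once in NewCrawlerTask instead of being re-branched on Config.FilterMode for every result in tabTask.Task. A minimal sketch of how the new constructors compose behind that interface; the buildFilter helper and the plain "smart"/"strict" mode strings are illustrative stand-ins for the config.SmartFilterMode / config.StrictFilterMode constants used above, not part of the change:

package main

import (
    "fmt"

    "github.com/Qianlitp/crawlergo/pkg/filter"
)

// buildFilter is a hypothetical helper mirroring the mode switch in NewCrawlerTask:
// every implementation is reached through the same FilterHandler interface.
func buildFilter(host, mode string) filter.FilterHandler {
    base := filter.NewSimpleFilter(host) // host limit + static-suffix set + URL dedup
    switch mode {
    case "smart":
        return filter.NewSmartFilter(base, false) // smart dedup, relaxed
    case "strict":
        return filter.NewSmartFilter(base, true) // smart dedup, strict mode
    default:
        return base // simple mode
    }
}

func main() {
    f := buildFilter("example.com", "strict")
    fmt.Printf("selected filter implementation: %T\n", f)
    // Callers such as tabTask.Task only ever see the interface:
    //     if !f.DoFilter(req) { /* keep and enqueue req */ }
}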