diff --git a/engine/dispatcher/dispatcher.go b/engine/dispatcher/dispatcher.go
index 6e905991..4ac5d9b9 100644
--- a/engine/dispatcher/dispatcher.go
+++ b/engine/dispatcher/dispatcher.go
@@ -6,12 +6,23 @@ package dispatcher
 
 import (
 	"errors"
+	"fmt"
 	"log/slog"
 	"os"
 	"time"
 
 	"github.com/caffix/queue"
 	et "github.com/owasp-amass/amass/v4/engine/types"
+	oam "github.com/owasp-amass/open-asset-model"
+)
+
+// Bounds used to throttle how many events are placed on the pipeline queues.
+const (
+	// MinPipelineQueueSize is the pipeline queue length below which events
+	// are scheduled directly and the queue is refilled from the backlogs.
+	MinPipelineQueueSize = 100
+	// MaxPipelineQueueSize is divided across the sessions to size a refill request.
+	MaxPipelineQueueSize = 500
 )
 
 type dis struct {
@@ -35,7 +46,7 @@ func NewDispatcher(l *slog.Logger, r et.Registry, mgr et.SessionManager) et.Disp
 		completed: queue.NewQueue(),
 	}
 
-	go d.collectEvents()
+	go d.maintainPipelines()
 	return d
 }
 
@@ -48,19 +59,26 @@ func (d *dis) Shutdown() {
 	close(d.done)
 }
 
-func (d *dis) collectEvents() {
-	t := time.NewTicker(100 * time.Millisecond)
-	defer t.Stop()
+// maintainPipelines runs until the dispatcher is shut down, topping up the
+// pipeline queues from the session backlogs once per second and passing
+// completed events to completedCallback.
+func (d *dis) maintainPipelines() {
+	ctick := time.NewTicker(5 * time.Second)
+	defer ctick.Stop()
+	qtick := time.NewTicker(time.Second)
+	defer qtick.Stop()
 loop:
 	for {
 		select {
 		case <-d.done:
 			break loop
+		case <-qtick.C:
+			d.fillPipelineQueues()
 		case <-d.completed.Signal():
 			if element, ok := d.completed.Next(); ok {
 				d.completedCallback(element)
 			}
-		case <-t.C:
+		case <-ctick.C:
 			if element, ok := d.completed.Next(); ok {
 				d.completedCallback(element)
 			}
@@ -69,6 +87,59 @@ loop:
 	d.completed.Process(d.completedCallback)
 }
 
+// fillPipelineQueues moves backlogged entities from the session queues into
+// every asset pipeline whose queue has dropped below MinPipelineQueueSize.
+func (d *dis) fillPipelineQueues() {
+	sessions := d.mgr.GetSessions()
+	if len(sessions) == 0 {
+		return
+	}
+
+	var ptypes []oam.AssetType
+	for _, atype := range oam.AssetList {
+		if ap, err := d.reg.GetPipeline(atype); err == nil {
+			if ap.Queue.Len() < MinPipelineQueueSize {
+				ptypes = append(ptypes, atype)
+			}
+		}
+	}
+	if len(ptypes) == 0 {
+		return
+	}
+
+	// integer division yields zero once there are more sessions than
+	// MaxPipelineQueueSize, which would stop every backlog from draining
+	numRequested := MaxPipelineQueueSize / len(sessions)
+	if numRequested < 1 {
+		numRequested = 1
+	}
+
+	for _, s := range sessions {
+		if s == nil || s.Done() {
+			continue
+		}
+		for _, atype := range ptypes {
+			entities, err := s.Queue().Next(atype, numRequested)
+			if err != nil {
+				continue
+			}
+			for _, entity := range entities {
+				if entity == nil || entity.Asset == nil {
+					continue
+				}
+				event := &et.Event{
+					Name:    fmt.Sprintf("%s - %s", string(atype), entity.Asset.Key()),
+					Entity:  entity,
+					Session: s,
+				}
+				if err := d.appendToPipelineQueue(event); err != nil {
+					s.Log().WithGroup("event").With("name", event.Name).Error(err.Error())
+				}
+			}
+		}
+	}
+}
+
 func (d *dis) completedCallback(data interface{}) {
 	ede, ok := data.(*et.EventDataElement)
 	if !ok {
@@ -95,21 +166,47 @@ func (d *dis) DispatchEvent(e *et.Event) error {
 	} else if e.Entity == nil || e.Entity.Asset == nil {
 		return errors.New("the event has no associated entity or asset")
 	}
+	// do not schedule the same asset more than once
+	set := e.Session.EventSet()
+	if set.Has(e.Entity.ID) {
+		return errors.New("this event was processed previously")
+	}
 
 	ap, err := d.reg.GetPipeline(e.Entity.Asset.AssetType())
 	if err != nil {
 		return err
 	}
 
-	e.Dispatcher = d
-	// do not schedule the same asset more than once
-	set := e.Session.EventSet()
-	if set.Has(e.Entity.ID) {
-		return errors.New("this event was processed previously")
+	// events carrying metadata, or arriving while the pipeline queue is short,
+	// go straight to the pipeline
+	if e.Meta != nil || ap.Queue.Len() < MinPipelineQueueSize {
+		return d.appendToPipelineQueue(e)
+	}
+
+	// otherwise park the entity in the session backlog and record it, so that
+	// repeated dispatches cannot enqueue the same entity more than once
+	if err := e.Session.Queue().Append(e.Entity); err != nil {
+		return err
+	}
+	set.Insert(e.Entity.ID)
+	return nil
+}
+
+// appendToPipelineQueue places the event on the pipeline queue for its asset
+// type and records the entity in the session event set.
+func (d *dis) appendToPipelineQueue(e *et.Event) error {
+	if e == nil || e.Session == nil || e.Entity == nil || e.Entity.Asset == nil {
+		return errors.New("the event is nil")
 	}
-	set.Insert(e.Entity.ID)
 
+	ap, err := d.reg.GetPipeline(e.Entity.Asset.AssetType())
+	if err != nil {
+		return err
+	}
+
+	e.Dispatcher = d
 	if data := et.NewEventDataElement(e); data != nil {
+		e.Session.EventSet().Insert(e.Entity.ID)
 		data.Queue = d.completed
 		ap.Queue.Append(data)
 		// increment the number of events processed in the session
diff --git a/engine/sessions/manager.go b/engine/sessions/manager.go
index 30f0861e..ba21ddda 100644
--- a/engine/sessions/manager.go
+++ b/engine/sessions/manager.go
@@ -111,6 +111,19 @@ func (r *manager) CancelSession(id uuid.UUID) {
 	delete(r.sessions, id)
 }
 
+// GetSessions: returns a new slice holding every session in the session storage.
+func (r *manager) GetSessions() []et.Session {
+	r.RLock()
+	defer r.RUnlock()
+
+	sessions := make([]et.Session, 0, len(r.sessions))
+	for _, s := range r.sessions {
+		sessions = append(sessions, s)
+	}
+
+	return sessions
+}
+
 // GetSession: returns a session from a session storage.
 func (r *manager) GetSession(id uuid.UUID) et.Session {
 	r.RLock()
diff --git a/engine/sessions/session.go b/engine/sessions/session.go
index d3a9680a..0961074e 100644
--- a/engine/sessions/session.go
+++ b/engine/sessions/session.go
@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/caffix/stringset"
@@ -24,6 +25,8 @@ import (
 	"github.com/owasp-amass/asset-db/repository"
 	"github.com/owasp-amass/asset-db/repository/neo4j"
 	"github.com/owasp-amass/asset-db/repository/sqlrepo"
+	dbt "github.com/owasp-amass/asset-db/types"
+	oam "github.com/owasp-amass/open-asset-model"
 	"github.com/yl2chen/cidranger"
 )
 
@@ -34,9 +37,10 @@ type Session struct {
 	cfg    *config.Config
 	scope  *scope.Scope
 	db     repository.Repository
+	queue  *sessionQueue
 	dsn    string
 	dbtype string
-	c      *cache.Cache
+	cache  *cache.Cache
 	ranger cidranger.Ranger
 	tmpdir string
 	stats  *et.SessionStats
@@ -80,10 +84,15 @@ func CreateSession(cfg *config.Config) (et.Session, error) {
 		return nil, err
 	}
 
-	s.c, err = cache.New(c, s.db, time.Minute)
-	if err != nil || s.c == nil {
+	s.cache, err = cache.New(c, s.db, time.Minute)
+	if err != nil || s.cache == nil {
 		return nil, errors.New("failed to create the session cache")
 	}
+
+	s.queue = newSessionQueue(s)
+	s.log.Info("Session initialized")
+	s.log.Info("Temporary directory created", slog.String("dir", s.tmpdir))
+	s.log.Info("Database connection established", slog.String("dsn", s.dsn))
 	return s, nil
 }
 
@@ -112,7 +121,11 @@ func (s *Session) DB() repository.Repository {
 }
 
 func (s *Session) Cache() *cache.Cache {
-	return s.c
+	return s.cache
+}
+
+func (s *Session) Queue() et.SessionQueue {
+	return s.queue
 }
 
 func (s *Session) CIDRanger() cidranger.Ranger {
@@ -221,3 +234,77 @@ func (s *Session) createFileRepo(fname string) (repository.Repository, error) {
 	}
 	return c, nil
 }
+
+// sessionQueue is a mutex-guarded backlog of entity IDs keyed by asset type.
+type sessionQueue struct {
+	sync.Mutex
+	session *Session
+	q       map[string][]string
+}
+
+func newSessionQueue(s *Session) *sessionQueue {
+	return &sessionQueue{
+		session: s,
+		q:       make(map[string][]string),
+	}
+}
+
+// Append adds the entity ID to the backlog kept for the entity's asset type.
+func (sq *sessionQueue) Append(e *dbt.Entity) error {
+	// validate before locking so that a rejected entity never touches the map
+	if e == nil {
+		return errors.New("entity is nil")
+	}
+	if e.Asset == nil {
+		return errors.New("asset is nil")
+	}
+
+	key := string(e.Asset.AssetType())
+	if key == "" {
+		return errors.New("asset type is empty")
+	}
+	if e.ID == "" {
+		return errors.New("entity ID is empty")
+	}
+
+	sq.Lock()
+	defer sq.Unlock()
+
+	// append allocates the slice when the key is not present yet
+	sq.q[key] = append(sq.q[key], e.ID)
+	return nil
+}
+
+// Next removes up to num entity IDs queued for the asset type and returns the
+// matching entities from the session cache. IDs that are not found in the
+// cache are dropped. An error is returned when no entities are available.
+func (sq *sessionQueue) Next(atype oam.AssetType, num int) ([]*dbt.Entity, error) {
+	// a negative count would otherwise panic in the slice expressions below
+	if num <= 0 {
+		return nil, errors.New("no entities found")
+	}
+
+	key := string(atype)
+
+	sq.Lock()
+	ids := sq.q[key]
+	if len(ids) > num {
+		sq.q[key] = ids[num:]
+		ids = ids[:num]
+	} else {
+		delete(sq.q, key)
+	}
+	sq.Unlock()
+
+	results := make([]*dbt.Entity, 0, len(ids))
+	for _, id := range ids {
+		if e, err := sq.session.Cache().FindEntityById(id); err == nil {
+			results = append(results, e)
+		}
+	}
+
+	if len(results) == 0 {
+		return nil, errors.New("no entities found")
+	}
+	return results, nil
+}
diff --git a/engine/types/sessions.go b/engine/types/sessions.go
index c2bfd682..078beb16 100644
--- a/engine/types/sessions.go
+++ b/engine/types/sessions.go
@@ -16,6 +16,8 @@ import (
 	"github.com/owasp-amass/amass/v4/engine/sessions/scope"
 	"github.com/owasp-amass/asset-db/cache"
 	"github.com/owasp-amass/asset-db/repository"
+	dbt "github.com/owasp-amass/asset-db/types"
+	oam "github.com/owasp-amass/open-asset-model"
 	"github.com/yl2chen/cidranger"
 )
 
@@ -27,6 +29,7 @@ type Session interface {
 	Scope() *scope.Scope
 	DB() repository.Repository
 	Cache() *cache.Cache
+	Queue() SessionQueue
 	CIDRanger() cidranger.Ranger
 	TmpDir() string
 	Stats() *SessionStats
@@ -35,6 +38,14 @@ type Session interface {
 	Kill()
 }
 
+// SessionQueue holds the entities of a session that are waiting to be dispatched.
+type SessionQueue interface {
+	// Append adds the entity to the queue.
+	Append(e *dbt.Entity) error
+	// Next removes and returns up to num queued entities of the given asset type.
+	Next(atype oam.AssetType, num int) ([]*dbt.Entity, error)
+}
+
 type SessionStats struct {
 	sync.Mutex
 	WorkItemsCompleted int `json:"workItemsCompleted"`
@@ -46,6 +57,7 @@ type SessionManager interface {
 	AddSession(s Session) error
 	CancelSession(id uuid.UUID)
 	GetSession(id uuid.UUID) Session
+	GetSessions() []Session
 	Shutdown()
 }
 