Skip to content

Commit

Permalink
initial implementation of the new session entity queue
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Feb 24, 2025
1 parent 982122d commit 2720092
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 15 deletions.
89 changes: 78 additions & 11 deletions engine/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@ package dispatcher

import (
"errors"
"fmt"
"log/slog"
"os"
"time"

"github.com/caffix/queue"
et "github.com/owasp-amass/amass/v4/engine/types"
oam "github.com/owasp-amass/open-asset-model"
)

const (
MinPipelineQueueSize = 100
MaxPipelineQueueSize = 500
)

type dis struct {
Expand All @@ -35,7 +42,7 @@ func NewDispatcher(l *slog.Logger, r et.Registry, mgr et.SessionManager) et.Disp
completed: queue.NewQueue(),
}

go d.collectEvents()
go d.maintainPipelines()
return d
}

Expand All @@ -48,19 +55,23 @@ func (d *dis) Shutdown() {
close(d.done)
}

func (d *dis) collectEvents() {
t := time.NewTicker(100 * time.Millisecond)
defer t.Stop()
func (d *dis) maintainPipelines() {
ctick := time.NewTicker(5 * time.Second)
defer ctick.Stop()
qtick := time.NewTicker(time.Second)
defer qtick.Stop()
loop:
for {
select {
case <-d.done:
break loop
case <-qtick.C:
d.fillPipelineQueues()
case <-d.completed.Signal():
if element, ok := d.completed.Next(); ok {
d.completedCallback(element)
}
case <-t.C:
case <-ctick.C:
if element, ok := d.completed.Next(); ok {
d.completedCallback(element)
}
Expand All @@ -69,6 +80,43 @@ loop:
d.completed.Process(d.completedCallback)
}

func (d *dis) fillPipelineQueues() {
sessions := d.mgr.GetSessions()
if len(sessions) == 0 {
return
}

var ptypes []oam.AssetType
for _, atype := range oam.AssetList {
if ap, err := d.reg.GetPipeline(atype); err == nil {
if ap.Queue.Len() < MinPipelineQueueSize {
ptypes = append(ptypes, atype)
}
}
}

numRequested := MaxPipelineQueueSize / len(sessions)
for _, s := range sessions {
if s == nil || s.Done() {
continue
}
for _, atype := range ptypes {
if entities, err := s.Queue().Next(atype, numRequested); err == nil && len(entities) > 0 {
for _, entity := range entities {
event := &et.Event{
Name: fmt.Sprintf("%s - %s", string(atype), entity.Asset.Key()),
Entity: entity,
Session: s,
}
if err := d.appendToPipelineQueue(event); err != nil {
s.Log().WithGroup("event").With("name", event.Name).Error(err.Error())
}
}
}
}
}
}

func (d *dis) completedCallback(data interface{}) {
ede, ok := data.(*et.EventDataElement)
if !ok {
Expand All @@ -95,21 +143,40 @@ func (d *dis) DispatchEvent(e *et.Event) error {
} else if e.Entity == nil || e.Entity.Asset == nil {
return errors.New("the event has no associated entity or asset")
}
// do not schedule the same asset more than once
set := e.Session.EventSet()
if set.Has(e.Entity.ID) {
return errors.New("this event was processed previously")
}

ap, err := d.reg.GetPipeline(e.Entity.Asset.AssetType())
if err != nil {
return err
}

e.Dispatcher = d
// do not schedule the same asset more than once
set := e.Session.EventSet()
if set.Has(e.Entity.ID) {
return errors.New("this event was processed previously")
if qlen := ap.Queue.Len(); e.Meta != nil || qlen < MinPipelineQueueSize {
if err := d.appendToPipelineQueue(e); err != nil {
return err
}
return nil
}

return e.Session.Queue().Append(e.Entity)
}

func (d *dis) appendToPipelineQueue(e *et.Event) error {
if e == nil || e.Session == nil || e.Entity == nil || e.Entity.Asset == nil {
return errors.New("the event is nil")
}
set.Insert(e.Entity.ID)

ap, err := d.reg.GetPipeline(e.Entity.Asset.AssetType())
if err != nil {
return err
}

e.Dispatcher = d
if data := et.NewEventDataElement(e); data != nil {
e.Session.EventSet().Insert(e.Entity.ID)
data.Queue = d.completed
ap.Queue.Append(data)
// increment the number of events processed in the session
Expand Down
12 changes: 12 additions & 0 deletions engine/sessions/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ func (r *manager) CancelSession(id uuid.UUID) {
delete(r.sessions, id)
}

func (r *manager) GetSessions() []et.Session {
r.RLock()
defer r.RUnlock()

sessions := make([]et.Session, 0, len(r.sessions))
for _, s := range r.sessions {
sessions = append(sessions, s)
}

return sessions
}

// GetSession: returns a session from a session storage.
func (r *manager) GetSession(id uuid.UUID) et.Session {
r.RLock()
Expand Down
89 changes: 85 additions & 4 deletions engine/sessions/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/caffix/stringset"
Expand All @@ -24,6 +25,8 @@ import (
"github.com/owasp-amass/asset-db/repository"
"github.com/owasp-amass/asset-db/repository/neo4j"
"github.com/owasp-amass/asset-db/repository/sqlrepo"
dbt "github.com/owasp-amass/asset-db/types"
oam "github.com/owasp-amass/open-asset-model"
"github.com/yl2chen/cidranger"
)

Expand All @@ -34,9 +37,10 @@ type Session struct {
cfg *config.Config
scope *scope.Scope
db repository.Repository
queue *sessionQueue
dsn string
dbtype string
c *cache.Cache
cache *cache.Cache
ranger cidranger.Ranger
tmpdir string
stats *et.SessionStats
Expand Down Expand Up @@ -80,10 +84,15 @@ func CreateSession(cfg *config.Config) (et.Session, error) {
return nil, err
}

s.c, err = cache.New(c, s.db, time.Minute)
if err != nil || s.c == nil {
s.cache, err = cache.New(c, s.db, time.Minute)
if err != nil || s.cache == nil {
return nil, errors.New("failed to create the session cache")
}

s.queue = newSessionQueue(s)
s.log.Info("Session initialized")
s.log.Info("Temporary directory created", slog.String("dir", s.tmpdir))
s.log.Info("Database connection established", slog.String("dsn", s.dsn))
return s, nil
}

Expand Down Expand Up @@ -112,7 +121,11 @@ func (s *Session) DB() repository.Repository {
}

func (s *Session) Cache() *cache.Cache {
return s.c
return s.cache
}

func (s *Session) Queue() et.SessionQueue {
return s.queue
}

func (s *Session) CIDRanger() cidranger.Ranger {
Expand Down Expand Up @@ -221,3 +234,71 @@ func (s *Session) createFileRepo(fname string) (repository.Repository, error) {
}
return c, nil
}

type sessionQueue struct {
sync.Mutex
session *Session
q map[string][]string
}

func newSessionQueue(s *Session) *sessionQueue {
return &sessionQueue{
session: s,
q: make(map[string][]string),
}
}

func (sq *sessionQueue) Append(e *dbt.Entity) error {
sq.Lock()
defer sq.Unlock()

if e == nil {
return errors.New("entity is nil")
}
if e.Asset == nil {
return errors.New("asset is nil")
}

key := string(e.Asset.AssetType())
if key == "" {
return errors.New("asset type is empty")
}
if _, found := sq.q[key]; !found {
sq.q[key] = make([]string, 0)
}
if e.ID == "" {
return errors.New("entity ID is empty")
}

sq.q[key] = append(sq.q[key], e.ID)
return nil
}

func (sq *sessionQueue) Next(atype oam.AssetType, num int) ([]*dbt.Entity, error) {
var ids []string
key := string(atype)

sq.Lock()
if q, found := sq.q[key]; found {
if len(q) > num {
ids = q[:num]
sq.q[key] = q[num:]
} else {
ids = q
delete(sq.q, key)
}
}
sq.Unlock()

var results []*dbt.Entity
for _, id := range ids {
if e, err := sq.session.Cache().FindEntityById(id); err == nil {
results = append(results, e)
}
}

if len(results) == 0 {
return nil, errors.New("no entities found")
}
return results, nil
}
9 changes: 9 additions & 0 deletions engine/types/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/owasp-amass/amass/v4/engine/sessions/scope"
"github.com/owasp-amass/asset-db/cache"
"github.com/owasp-amass/asset-db/repository"
dbt "github.com/owasp-amass/asset-db/types"
oam "github.com/owasp-amass/open-asset-model"
"github.com/yl2chen/cidranger"
)

Expand All @@ -27,6 +29,7 @@ type Session interface {
Scope() *scope.Scope
DB() repository.Repository
Cache() *cache.Cache
Queue() SessionQueue
CIDRanger() cidranger.Ranger
TmpDir() string
Stats() *SessionStats
Expand All @@ -35,6 +38,11 @@ type Session interface {
Kill()
}

type SessionQueue interface {
Append(e *dbt.Entity) error
Next(atype oam.AssetType, num int) ([]*dbt.Entity, error)
}

type SessionStats struct {
sync.Mutex
WorkItemsCompleted int `json:"workItemsCompleted"`
Expand All @@ -46,6 +54,7 @@ type SessionManager interface {
AddSession(s Session) error
CancelSession(id uuid.UUID)
GetSession(id uuid.UUID) Session
GetSessions() []Session
Shutdown()
}

Expand Down

0 comments on commit 2720092

Please sign in to comment.