diff --git a/Update.md b/Update.md
deleted file mode 100644
index 2f0f2e5..0000000
--- a/Update.md
+++ /dev/null
@@ -1,16 +0,0 @@
-# Update
-
-So today I've succesfully generated a Go client for the nats-token-exchange api. The next step is to write the code that actually interfaces with it so that engines can be configured to authenticate automatically. The goal here is that you do so onme of the following:
-
-* Provide a client_id and client_secrate and the engine will:
-  * Auth with Auth0 and get a token
-  * Generate NKeys if required (or maybe always generate them? More secure?)
-  * Use that token to get a NATS token
-  * Authenticate to NATS
-  * Re-generate NATS and OAuth tokens as required
-* Provide an OAuth token and do all of the above except authenticating with auth0 and getting a new oAuth token when it expires
-  * This model is designed for use when the tokens are managed by an enternal party like srcman
-
-One difficult question is going to be how do we test this...
-
-**UPDATE:** I have now added a token exchange container, however I need to fix this issue before it will start: https://github.com/overmindtech/nats-token-exchange/issues/17
\ No newline at end of file
diff --git a/performance_test.go b/performance_test.go
new file mode 100644
index 0000000..2b1affd
--- /dev/null
+++ b/performance_test.go
@@ -0,0 +1,185 @@
+package discovery
+
+import (
+    "context"
+    "math"
+    "sync"
+    "testing"
+    "time"
+
+    "github.com/overmindtech/sdp-go"
+)
+
+type SlowSource struct {
+    RequestDuration time.Duration
+}
+
+func (s *SlowSource) Type() string {
+    return "person"
+}
+
+func (s *SlowSource) Name() string {
+    return "slow-source"
+}
+
+func (s *SlowSource) DefaultCacheDuration() time.Duration {
+    return 10 * time.Minute
+}
+
+func (s *SlowSource) Contexts() []string {
+    return []string{"test"}
+}
+
+func (s *SlowSource) Hidden() bool {
+    return false
+}
+
+func (s *SlowSource) Get(ctx context.Context, itemContext string, query string) (*sdp.Item, error) {
+    end := time.Now().Add(s.RequestDuration)
+    attributes, _ := sdp.ToAttributes(map[string]interface{}{
+        "name": query,
+    })
+
+    item := sdp.Item{
+        Type:               "person",
+        UniqueAttribute:    "name",
+        Attributes:         attributes,
+        Context:            "test",
+        LinkedItemRequests: []*sdp.ItemRequest{},
+    }
+
+    for i := 0; i != 2; i++ {
+        item.LinkedItemRequests = append(item.LinkedItemRequests, &sdp.ItemRequest{
+            Type:    "person",
+            Method:  sdp.RequestMethod_GET,
+            Query:   RandomName(),
+            Context: "test",
+        })
+    }
+
+    time.Sleep(time.Until(end))
+
+    return &item, nil
+}
+
+func (s *SlowSource) Find(ctx context.Context, itemContext string) ([]*sdp.Item, error) {
+    return []*sdp.Item{}, nil
+}
+
+func (s *SlowSource) Weight() int {
+    return 100
+}
+
+func TestParallelRequestPerformance(t *testing.T) {
+    // This test is designed to ensure that request duration is linear up to a
+    // certain point. Above that point the overhead caused by having so many
+    // goroutines running will start to make the response times non-linear,
+    // which isn't ideal, but given realistic loads we probably don't care.
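+    //
+    // As a worked example of the expectations below: with 10 requests, a link
+    // depth of 1 and a parallelism of 10, layer 0 is 10 items and layer 1 is
+    // 20 (every item links to two more), so we expect 30 items in
+    // 200ms + 400ms, since each batch of 10 is allowed 200ms: double the
+    // source's 100ms delay, to give some headroom.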
+ t.Run("Without linking", func(t *testing.T) { + RunLinearPerformanceTest(t, "10 requests", 10, 0, 1) + RunLinearPerformanceTest(t, "100 requests", 100, 0, 10) + RunLinearPerformanceTest(t, "1,000 requests", 1000, 0, 100) + }) + + t.Run("With linking", func(t *testing.T) { + RunLinearPerformanceTest(t, "1 request 3 depth", 1, 3, 1) + RunLinearPerformanceTest(t, "1 request 3 depth", 1, 3, 100) + RunLinearPerformanceTest(t, "1 request 5 depth", 1, 5, 100) + RunLinearPerformanceTest(t, "10 requests 5 depth", 10, 5, 100) + }) +} + +// RunLinearPerformanceTest Runs a test with a given number in input requests, +// link depth and parallelisation limit. Expected results and expected duration +// are determined automatically meaning all this is testing for is the fact that +// the perfomance continues to be linear and predictable +func RunLinearPerformanceTest(t *testing.T, name string, numRequests int, linkDepth int, numParallel int) { + t.Helper() + + t.Run(name, func(t *testing.T) { + result := TimeRequests(numRequests, linkDepth, numParallel) + + if len(result.Results) != result.ExpectedItems { + t.Errorf("Expected %v items, got %v", result.ExpectedItems, len(result.Results)) + } + + if result.TimeTaken > result.MaxTime { + t.Errorf("Requests took too long: %v Max: %v", result.TimeTaken.String(), result.MaxTime.String()) + } + }) +} + +type TimedResults struct { + ExpectedItems int + MaxTime time.Duration + TimeTaken time.Duration + Results []*sdp.Item + Errors []error +} + +func TimeRequests(numRequests int, linkDepth int, numParallel int) TimedResults { + engine := Engine{ + Name: "performance-test", + MaxParallelExecutions: numParallel, + } + engine.AddSources(&SlowSource{ + RequestDuration: 100 * time.Millisecond, + }) + engine.Start() + defer engine.Stop() + + // Calculate how many items to expect and the expected duration + var expectedItems int + var expectedDuration time.Duration + for i := 0; i <= linkDepth; i++ { + thisLayer := int(math.Pow(2, float64(i))) * numRequests + thisDuration := 200 * math.Ceil(float64(thisLayer)/float64(numParallel)) + expectedDuration = expectedDuration + (time.Duration(thisDuration) * time.Millisecond) + expectedItems = expectedItems + thisLayer + } + + results := make([]*sdp.Item, 0) + errors := make([]error, 0) + resultsMutex := sync.Mutex{} + wg := sync.WaitGroup{} + + start := time.Now() + + for i := 0; i < numRequests; i++ { + rt := RequestTracker{ + Request: &sdp.ItemRequest{ + Type: "person", + Method: sdp.RequestMethod_GET, + Query: RandomName(), + Context: "test", + LinkDepth: uint32(linkDepth), + }, + Engine: &engine, + } + + wg.Add(1) + + go func(rt *RequestTracker) { + defer wg.Done() + + items, err := rt.Execute() + + resultsMutex.Lock() + results = append(results, items...) 
+            if err != nil {
+                errors = append(errors, err)
+            }
+            resultsMutex.Unlock()
+        }(&rt)
+    }
+
+    wg.Wait()
+
+    return TimedResults{
+        ExpectedItems: expectedItems,
+        MaxTime:       expectedDuration,
+        TimeTaken:     time.Since(start),
+        Results:       results,
+        Errors:        errors,
+    }
+}
diff --git a/request_tracker_test.go b/request_tracker_test.go
index 056b931..2f69634 100644
--- a/request_tracker_test.go
+++ b/request_tracker_test.go
@@ -2,7 +2,6 @@ package discovery

 import (
     "context"
-    "fmt"
     "sync"
     "testing"
     "time"
@@ -86,96 +85,6 @@ func (s *SpeedTestSource) Weight() int {
     return 10
 }

-func TestExecuteParallel(t *testing.T) {
-    queryDelay := (200 * time.Millisecond)
-    numSources := 10
-    sources := make([]Source, numSources)
-
-    // Create a number of sources
-    for i := 0; i < len(sources); i++ {
-        sources[i] = &SpeedTestSource{
-            QueryDelay: queryDelay,
-            ReturnType: fmt.Sprintf("type%v", i),
-        }
-    }
-
-    t.Run("With no parallelism", func(t *testing.T) {
-        t.Parallel()
-
-        engine := Engine{
-            Name:                  "no-parallel",
-            MaxParallelExecutions: 1,
-        }
-
-        engine.AddSources(sources...)
-        engine.SetupThrottle()
-
-        tracker := RequestTracker{
-            Engine: &engine,
-            Request: &sdp.ItemRequest{
-                Type:      "*",
-                Method:    sdp.RequestMethod_FIND,
-                LinkDepth: 0,
-                Context:   "*",
-            },
-        }
-
-        timeStart := time.Now()
-
-        _, err := tracker.Execute()
-
-        timeTaken := time.Since(timeStart)
-
-        if err != nil {
-            t.Fatal(err)
-        }
-
-        expectedTime := time.Duration(int64(queryDelay) * int64(numSources))
-
-        if timeTaken < expectedTime {
-            t.Errorf("Query with no parallelism took < %v. This means it must have run in parallel", expectedTime)
-        }
-    })
-
-    t.Run("With lots of parallelism", func(t *testing.T) {
-        t.Parallel()
-
-        engine := Engine{
-            Name:                  "no-parallel",
-            MaxParallelExecutions: 999,
-        }
-
-        engine.AddSources(sources...)
-        engine.SetupThrottle()
-
-        tracker := RequestTracker{
-            Engine: &engine,
-            Request: &sdp.ItemRequest{
-                Type:      "*",
-                Method:    sdp.RequestMethod_FIND,
-                LinkDepth: 0,
-                Context:   "*",
-            },
-        }
-
-        timeStart := time.Now()
-
-        _, err := tracker.Execute()
-
-        timeTaken := time.Since(timeStart)
-
-        if err != nil {
-            t.Fatal(err)
-        }
-
-        expectedTime := (queryDelay * 2) // Double it give us some wiggle room
-
-        if timeTaken > expectedTime {
-            t.Errorf("Query with no parallelism took %v which is > than the expected max of %v. This means it must not have run in parallel", timeTaken, expectedTime)
-        }
-    })
-}
-
 func TestExecute(t *testing.T) {
     engine := Engine{
         Name: "test",
diff --git a/requests.go b/requests.go
index 3cba64b..83f5e04 100644
--- a/requests.go
+++ b/requests.go
@@ -140,10 +140,31 @@ func (e *Engine) ExecuteRequest(ctx context.Context, req *sdp.ItemRequest) ([]*s
         }
     }

+    allItems := make([]*sdp.Item, 0)
+    allErrors := make([]error, 0)
+
+    go func() {
+        for item := range items {
+            allItems = append(allItems, item)
+        }
+        done <- true
+    }()
+
+    go func() {
+        for err := range errors {
+            allErrors = append(allErrors, err)
+        }
+        done <- true
+    }()
+
     for request, sources := range expanded {
         wg.Add(1)
+
+        e.throttle.Lock()
+
         go func(r *sdp.ItemRequest, sources []Source) {
             defer wg.Done()
+            defer e.throttle.Unlock()

             var requestItems []*sdp.Item
             var requestError error
@@ -188,23 +209,6 @@ func (e *Engine) ExecuteRequest(ctx context.Context, req *sdp.ItemRequest) ([]*s
         }(request, sources)
     }

-    allItems := make([]*sdp.Item, 0)
-    allErrors := make([]error, 0)
-
-    go func() {
-        for item := range items {
-            allItems = append(allItems, item)
-        }
-        done <- true
-    }()
-
-    go func() {
-        for err := range errors {
-            allErrors = append(allErrors, err)
-        }
-        done <- true
-    }()
-
     // Wait for all requests to complete
     wg.Wait()
diff --git a/shared_test.go b/shared_test.go
index a59b1f8..cbd9137 100644
--- a/shared_test.go
+++ b/shared_test.go
@@ -12,19 +12,24 @@ import (
     "google.golang.org/protobuf/types/known/structpb"
 )

-var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

-func randSeq(n int) string {
-    b := make([]rune, n)
+func randString(length int) string {
+    var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
+
+    b := make([]byte, length)
     for i := range b {
-        b[i] = letters[rand.Intn(len(letters))]
+        b[i] = charset[seededRand.Intn(len(charset))]
     }
     return string(b)
 }

 func RandomName() string {
-    n := namegenerator.NewNameGenerator(time.Now().UTC().UnixNano())
-    return n.Generate() + " " + n.Generate() + "-" + randSeq(10)
+    seed := time.Now().UTC().UnixNano()
+    nameGenerator := namegenerator.NewNameGenerator(seed)
+    name := nameGenerator.Generate()
+    randGarbage := randString(10)
+
+    return fmt.Sprintf("%v-%v", name, randGarbage)
 }

 func (s *TestSource) NewTestItem(itemContext string, query string) *sdp.Item {
diff --git a/source.go b/source.go
index b5b93fb..7573eb6 100644
--- a/source.go
+++ b/source.go
@@ -3,7 +3,6 @@ package discovery
 import (
     "context"
     "fmt"
-    "sync"
     "time"

     "github.com/overmindtech/sdp-go"
@@ -161,7 +160,6 @@ func (e *Engine) Get(ctx context.Context, r *sdp.ItemRequest, relevantSources []
         }
     }

-    e.throttle.Lock()
     log.WithFields(logFields).Debug("Executing get for backend")

     var getDuration time.Duration
@@ -172,8 +170,6 @@ func (e *Engine) Get(ctx context.Context, r *sdp.ItemRequest, relevantSources []
         item, err = src.Get(ctx, r.Context, r.Query)
     })

-    e.throttle.Unlock()
-
     logFields["itemFound"] = (err == nil)
     logFields["error"] = err
@@ -235,9 +231,6 @@ func (e *Engine) Get(ctx context.Context, r *sdp.ItemRequest, relevantSources []
 // results. Only returns an error if all sources fail, in which case returns the
 // first error
 func (e *Engine) Find(ctx context.Context, r *sdp.ItemRequest, relevantSources []Source) ([]*sdp.Item, error) {
-    var storageMutex sync.Mutex
-    var workingSources sync.WaitGroup
-
     if len(relevantSources) == 0 {
         return nil, &sdp.ItemRequestError{
             ErrorType:   sdp.ItemRequestError_NOCONTEXT,
@@ -253,123 +246,107 @@ func (e *Engine) Find(ctx context.Context, r *sdp.ItemRequest, relevantSources [
     errors := make([]error, 0)

     for _, src := range relevantSources {
-        workingSources.Add(1)
-        go func(source Source) {
-            defer workingSources.Done()
-
-            tags := sdpcache.Tags{
-                "method":     "find",
-                "sourceName": source.Name(),
-                "context":    r.Context,
-            }
+        tags := sdpcache.Tags{
+            "method":     "find",
+            "sourceName": src.Name(),
+            "context":    r.Context,
+        }

-            logFields := log.Fields{
-                "sourceName": source.Name(),
-                "type":       r.Type,
-                "context":    r.Context,
-            }
+        logFields := log.Fields{
+            "sourceName": src.Name(),
+            "type":       r.Type,
+            "context":    r.Context,
+        }
+
+        if !r.IgnoreCache {
+            cachedItems, err := e.cache.Search(tags)
+
+            switch err := err.(type) {
+            case sdpcache.CacheNotFoundError:
+                // If the item/error wasn't found in the cache then just
+                // continue on
+            case *sdp.ItemRequestError:
+                if err.ErrorType == sdp.ItemRequestError_NOTFOUND {
+                    log.WithFields(logFields).Debug("Found cached empty FIND, not executing")
-            if !r.IgnoreCache {
-                cachedItems, err := e.cache.Search(tags)
-
-                switch err := err.(type) {
-                case sdpcache.CacheNotFoundError:
-                    // If the item/error wasn't found in the cache then just
-                    // continue on
-                case *sdp.ItemRequestError:
-                    if err.ErrorType == sdp.ItemRequestError_NOTFOUND {
-                        log.WithFields(logFields).Debug("Found cached empty FIND, not executing")
-
-                        return
-                    }
-                default:
-                    // If we get a result from the cache then return that
-                    if len(cachedItems) > 0 {
-                        logFields["items"] = len(cachedItems)
-
-                        log.WithFields(logFields).Debug("Found items from cache")
-
-                        storageMutex.Lock()
-                        items = append(items, cachedItems...)
-                        errors = append(errors, err)
-                        storageMutex.Unlock()
-
-                        return
-                    }
+                    continue
                 }
-            }
+            default:
+                // If we get a result from the cache then use that
+                if len(cachedItems) > 0 {
+                    logFields["items"] = len(cachedItems)

-            e.throttle.Lock()
-            log.WithFields(logFields).Debug("Executing find")
+                    log.WithFields(logFields).Debug("Found items from cache")

-            finds := make([]*sdp.Item, 0)
-            var err error
+                    items = append(items, cachedItems...)
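+                    // Keep the (usually nil) cache-search error too, so the
+                    // success check at the bottom of this function counts a
+                    // cache hit as a successful run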
+                    errors = append(errors, err)

-            findDuration := timeOperation(func() {
-                finds, err = source.Find(ctx, r.Context)
-            })
+                    continue
+                }
+            }
+        }

-            e.throttle.Unlock()
+        log.WithFields(logFields).Debug("Executing find")

-            logFields["items"] = len(finds)
-            logFields["error"] = err
+        finds := make([]*sdp.Item, 0)
+        var err error

-            if err == nil {
-                log.WithFields(logFields).Debug("Find complete")
+        findDuration := timeOperation(func() {
+            finds, err = src.Find(ctx, r.Context)
+        })

-                // Check too see if nothing was found, make sure we cache the
-                // nothing
-                if len(finds) == 0 {
-                    e.cache.StoreError(&sdp.ItemRequestError{
-                        ErrorType: sdp.ItemRequestError_NOTFOUND,
-                    }, GetCacheDuration(source), tags)
-                }
-            } else {
-                log.WithFields(logFields).Error("Error during find")
+        logFields["items"] = len(finds)
+        logFields["error"] = err

-                e.cache.StoreError(err, GetCacheDuration(source), tags)
+        if err == nil {
+            log.WithFields(logFields).Debug("Find complete")
+
+            // Check to see if nothing was found, and make sure we cache the
+            // nothing
+            if len(finds) == 0 {
+                e.cache.StoreError(&sdp.ItemRequestError{
+                    ErrorType: sdp.ItemRequestError_NOTFOUND,
+                }, GetCacheDuration(src), tags)
             }
+        } else {
+            log.WithFields(logFields).Error("Error during find")

-            // For each found item, add more details
-            //
-            // Use the index here to ensure that we're actually editing the
-            // right thing
-            for i := range finds {
-                // Get a pointer to the item we're dealing with
-                item := finds[i]
+            e.cache.StoreError(err, GetCacheDuration(src), tags)

-                // Handle the case where we are given a nil pointer
-                if item == nil {
-                    continue
-                }
+        // For each found item, add more details
+        //
+        // Use the index here to ensure that we're actually editing the
+        // right thing
+        for i := range finds {
+            // Get a pointer to the item we're dealing with
+            item := finds[i]

-                // Store metadata
-                item.Metadata = &sdp.Metadata{
-                    Timestamp:             timestamppb.New(time.Now()),
-                    SourceDuration:        durationpb.New(findDuration),
-                    SourceDurationPerItem: durationpb.New(time.Duration(findDuration.Nanoseconds() / int64(len(finds)))),
-                    SourceName:            source.Name(),
-                }
+            // Handle the case where we are given a nil pointer
+            if item == nil {
+                continue
+            }

-                // Mark the item as hidden if the source is a hidden source
-                if hs, ok := source.(HiddenSource); ok {
-                    item.Metadata.Hidden = hs.Hidden()
-                }
+            // Store metadata
+            item.Metadata = &sdp.Metadata{
+                Timestamp:             timestamppb.New(time.Now()),
+                SourceDuration:        durationpb.New(findDuration),
+                SourceDurationPerItem: durationpb.New(time.Duration(findDuration.Nanoseconds() / int64(len(finds)))),
+                SourceName:            src.Name(),
+            }

-                // Cache the item
-                e.cache.StoreItem(item, GetCacheDuration(source), tags)
+            // Mark the item as hidden if the source is a hidden source
+            if hs, ok := src.(HiddenSource); ok {
+                item.Metadata.Hidden = hs.Hidden()
             }

-            storageMutex.Lock()
-            items = append(items, finds...)
-            errors = append(errors, err)
-            storageMutex.Unlock()
-        }(src)
-    }
+            // Cache the item
+            e.cache.StoreItem(item, GetCacheDuration(src), tags)
+        }

-    workingSources.Wait()
-    storageMutex.Lock()
-    defer storageMutex.Unlock()
+        items = append(items, finds...)
+        errors = append(errors, err)
+    }

     // Check if there were any successful runs and if so return the items
     for _, e := range errors {
@@ -389,9 +366,6 @@ func (e *Engine) Find(ctx context.Context, r *sdp.ItemRequest, relevantSources [
 // results. Only returns an error if all sources fail, in which case returns the
 // first error
 func (e *Engine) Search(ctx context.Context, r *sdp.ItemRequest, relevantSources []Source) ([]*sdp.Item, error) {
-    var storageMutex sync.Mutex
-    var workingSources sync.WaitGroup
-
     searchableSources := make([]SearchableSource, 0)

     // Filter further by searchability
@@ -416,124 +390,108 @@ func (e *Engine) Search(ctx context.Context, r *sdp.ItemRequest, relevantSources
     errors := make([]error, 0)

     for _, src := range searchableSources {
-        workingSources.Add(1)
-        go func(source SearchableSource) {
-            defer workingSources.Done()
-
-            tags := sdpcache.Tags{
-                "method":     "find",
-                "sourceName": source.Name(),
-                "query":      r.Query,
-                "context":    r.Context,
-            }
+        tags := sdpcache.Tags{
+            "method":     "find",
+            "sourceName": src.Name(),
+            "query":      r.Query,
+            "context":    r.Context,
+        }

-            logFields := log.Fields{
-                "sourceName": source.Name(),
-                "type":       r.Type,
-                "context":    r.Context,
-            }
+        logFields := log.Fields{
+            "sourceName": src.Name(),
+            "type":       r.Type,
+            "context":    r.Context,
+        }
+
+        if !r.IgnoreCache {
+            cachedItems, err := e.cache.Search(tags)
+
+            switch err := err.(type) {
+            case sdpcache.CacheNotFoundError:
+                // If the item/error wasn't found in the cache then just
+                // continue on
+            case *sdp.ItemRequestError:
+                if err.ErrorType == sdp.ItemRequestError_NOTFOUND {
+                    log.WithFields(logFields).Debug("Found cached empty result, not executing")
-            if !r.IgnoreCache {
-                cachedItems, err := e.cache.Search(tags)
-
-                switch err := err.(type) {
-                case sdpcache.CacheNotFoundError:
-                    // If the item/error wasn't found in the cache then just
-                    // continue on
-                case *sdp.ItemRequestError:
-                    if err.ErrorType == sdp.ItemRequestError_NOTFOUND {
-                        log.WithFields(logFields).Debug("Found cached empty result, not executing")
-
-                        return
-                    }
-                default:
-                    // If we get a result from the cache then return that
-                    if len(cachedItems) > 0 {
-                        logFields["items"] = len(cachedItems)
-
-                        log.WithFields(logFields).Debug("Found items from cache")
-
-                        storageMutex.Lock()
-                        items = append(items, cachedItems...)
-                        errors = append(errors, err)
-                        storageMutex.Unlock()
-
-                        return
-                    }
+                    continue
                 }
-            }
+            default:
+                // If we get a result from the cache then use that
+                if len(cachedItems) > 0 {
+                    logFields["items"] = len(cachedItems)

-            e.throttle.Lock()
-            log.WithFields(logFields).Debug("Executing search")
+                    log.WithFields(logFields).Debug("Found items from cache")

-            var searchItems []*sdp.Item
-            var err error
+                    items = append(items, cachedItems...)
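+                    // As in Find above, the err recorded here is the nil
+                    // cache-search error, which marks this source as having
+                    // run successfully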
+                    errors = append(errors, err)

-            searchDuration := timeOperation(func() {
-                searchItems, err = source.Search(ctx, r.Context, r.Query)
-            })
+                    continue
+                }
+            }
+        }

-            e.throttle.Unlock()
+        log.WithFields(logFields).Debug("Executing search")

-            logFields["items"] = len(searchItems)
-            logFields["error"] = err
+        var searchItems []*sdp.Item
+        var err error

-            if err == nil {
-                log.WithFields(logFields).Debug("Search completed")
+        searchDuration := timeOperation(func() {
+            searchItems, err = src.Search(ctx, r.Context, r.Query)
+        })

-                // Check too see if nothing was found, make sure we cache the
-                // nothing
-                if len(searchItems) == 0 {
-                    e.cache.StoreError(&sdp.ItemRequestError{
-                        ErrorType: sdp.ItemRequestError_NOTFOUND,
-                    }, GetCacheDuration(source), tags)
-                }
-            } else {
-                log.WithFields(logFields).Error("Error during search")
+        logFields["items"] = len(searchItems)
+        logFields["error"] = err

-                e.cache.StoreError(err, GetCacheDuration(source), tags)
+        if err == nil {
+            log.WithFields(logFields).Debug("Search completed")
+
+            // Check to see if nothing was found, and make sure we cache the
+            // nothing
+            if len(searchItems) == 0 {
+                e.cache.StoreError(&sdp.ItemRequestError{
+                    ErrorType: sdp.ItemRequestError_NOTFOUND,
+                }, GetCacheDuration(src), tags)
             }
+        } else {
+            log.WithFields(logFields).Error("Error during search")

-            // For each found item, add more details
-            //
-            // Use the index here to ensure that we're actually editing the
-            // right thing
-            for i := range searchItems {
-                // Get a pointer to the item we're dealing with
-                item := searchItems[i]
+            e.cache.StoreError(err, GetCacheDuration(src), tags)

-                // Handle the case where we are given a nil pointer
-                if item == nil {
-                    continue
-                }
+        // For each found item, add more details
+        //
+        // Use the index here to ensure that we're actually editing the
+        // right thing
+        for i := range searchItems {
+            // Get a pointer to the item we're dealing with
+            item := searchItems[i]

-                // Store metadata
-                item.Metadata = &sdp.Metadata{
-                    Timestamp:             timestamppb.New(time.Now()),
-                    SourceDuration:        durationpb.New(searchDuration),
-                    SourceDurationPerItem: durationpb.New(time.Duration(searchDuration.Nanoseconds() / int64(len(searchItems)))),
-                    SourceName:            source.Name(),
-                }
+            // Handle the case where we are given a nil pointer
+            if item == nil {
+                continue
+            }

-                // Mark the item as hidden if the source is a hidden source
-                if hs, ok := source.(HiddenSource); ok {
-                    item.Metadata.Hidden = hs.Hidden()
-                }
+            // Store metadata
+            item.Metadata = &sdp.Metadata{
+                Timestamp:             timestamppb.New(time.Now()),
+                SourceDuration:        durationpb.New(searchDuration),
+                SourceDurationPerItem: durationpb.New(time.Duration(searchDuration.Nanoseconds() / int64(len(searchItems)))),
+                SourceName:            src.Name(),
+            }

-                // Cache the item
-                e.cache.StoreItem(item, GetCacheDuration(source), tags)
+            // Mark the item as hidden if the source is a hidden source
+            if hs, ok := src.(HiddenSource); ok {
+                item.Metadata.Hidden = hs.Hidden()
             }

-            storageMutex.Lock()
-            items = append(items, searchItems...)
-            errors = append(errors, err)
-            storageMutex.Unlock()
-        }(src)
-    }
+            // Cache the item
+            e.cache.StoreItem(item, GetCacheDuration(src), tags)
+        }

-    workingSources.Wait()
-    storageMutex.Lock()
-    defer storageMutex.Unlock()
+        items = append(items, searchItems...)
+        errors = append(errors, err)
+    }

     // Check if there were any successful runs and if so return the items
     for _, e := range errors {
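
A note on the throttling pattern introduced in requests.go above: the lock is acquired before the goroutine is spawned and released inside it via defer, so the loop itself blocks once MaxParallelExecutions requests are in flight. The engine's throttle type is not shown in this diff, so the sketch below is illustrative only; it assumes the throttle behaves like a counting semaphore with Lock/Unlock, and the Throttle type and sizes here are hypothetical.

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // Throttle is a hypothetical stand-in for the engine's throttle: a
    // counting semaphore where Lock blocks until one of the slots is free.
    type Throttle struct {
        slots chan struct{}
    }

    func NewThrottle(maxParallel int) *Throttle {
        return &Throttle{slots: make(chan struct{}, maxParallel)}
    }

    func (t *Throttle) Lock()   { t.slots <- struct{}{} }
    func (t *Throttle) Unlock() { <-t.slots }

    func main() {
        throttle := NewThrottle(3) // at most 3 requests in flight
        var wg sync.WaitGroup

        for i := 0; i < 10; i++ {
            wg.Add(1)

            // Acquire a slot before spawning the goroutine, mirroring the
            // new ExecuteRequest: the loop stalls here at the limit
            throttle.Lock()

            go func(n int) {
                defer wg.Done()
                defer throttle.Unlock()

                time.Sleep(100 * time.Millisecond) // stand-in for a slow source
                fmt.Println("request", n, "done")
            }(i)
        }

        wg.Wait()
    }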