diff --git a/.circleci/config.yml b/.circleci/config.yml index 970c892..a6a78bb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 jobs: build: docker: - - image: golang:1.11 + - image: golang:1.18 working_directory: /projects/dataloaden steps: &steps - checkout diff --git a/dataloaden.go b/dataloaden.go index 3419286..4bc751d 100644 --- a/dataloaden.go +++ b/dataloaden.go @@ -1,28 +1,217 @@ -package main +package dataloaden import ( - "fmt" - "os" - - "github.com/vektah/dataloaden/pkg/generator" + "sync" + "time" ) -func main() { - if len(os.Args) != 4 { - fmt.Println("usage: name keyType valueType") - fmt.Println(" example:") - fmt.Println(" dataloaden 'UserLoader int []*github.com/my/package.User'") - os.Exit(1) +// LoaderConfig captures the config to create a new Loader +type LoaderConfig[K comparable, T any] struct { + // Fetch is a method that provides the data for the loader + Fetch func(keys []K) ([]T, []error) + + // Wait is how long wait before sending a batch + Wait time.Duration + + // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit + MaxBatch int +} + +// NewLoader creates a new Loader given a fetch, wait, and maxBatch +func NewLoader[K comparable, T any](config LoaderConfig[K, T]) *Loader[K, T] { + return &Loader[K, T]{ + fetch: config.Fetch, + wait: config.Wait, + maxBatch: config.MaxBatch, } +} + +// Loader batches and caches requests +type Loader[K comparable, T any] struct { + // this method provides the data for the loader + fetch func(keys []K) ([]T, []error) + + // how long to done before sending a batch + wait time.Duration + + // this will limit the maximum number of keys to send in one batch, 0 = no limit + maxBatch int + + // INTERNAL + + // lazily created cache + cache map[K]T + + // the current batch. keys will continue to be collected until timeout is hit, + // then everything will be sent to the fetch method and out to the listeners + batch *loaderBatch[K, T] + + // mutex to prevent races + mu sync.Mutex +} + +type loaderBatch[K comparable, T any] struct { + keys []K + data []T + error []error + closing bool + done chan struct{} +} + +// Load a User by key, batching and caching will be applied automatically +func (l *Loader[K, T]) Load(key K) (T, error) { + return l.LoadThunk(key)() +} + +// LoadThunk returns a function that when called will block waiting for a User. +// This method should be used if you want one goroutine to make requests to many +// different data loaders without blocking until the thunk is called. +func (l *Loader[K, T]) LoadThunk(key K) func() (T, error) { + l.mu.Lock() + if it, ok := l.cache[key]; ok { + l.mu.Unlock() + return func() (T, error) { + return it, nil + } + } + if l.batch == nil { + l.batch = &loaderBatch[K, T]{done: make(chan struct{})} + } + batch := l.batch + pos := batch.keyIndex(l, key) + l.mu.Unlock() + + return func() (T, error) { + <-batch.done + + var data T + if pos < len(batch.data) { + data = batch.data[pos] + } + + var err error + // its convenient to be able to return a single error for everything + if len(batch.error) == 1 { + err = batch.error[0] + } else if batch.error != nil { + err = batch.error[pos] + } + + if err == nil { + l.mu.Lock() + l.unsafeSet(key, data) + l.mu.Unlock() + } + + return data, err + } +} + +// LoadAll fetches many keys at once. It will be broken into appropriate sized +// sub batches depending on how the loader is configured +func (l *Loader[K, T]) LoadAll(keys []K) ([]T, []error) { + results := make([]func() (T, error), len(keys)) - wd, err := os.Getwd() - if err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(2) + for i, key := range keys { + results[i] = l.LoadThunk(key) } - if err := generator.Generate(os.Args[1], os.Args[2], os.Args[3], wd); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(2) + users := make([]T, len(keys)) + errors := make([]error, len(keys)) + for i, thunk := range results { + users[i], errors[i] = thunk() } + return users, errors +} + +// LoadAllThunk returns a function that when called will block waiting for a Users. +// This method should be used if you want one goroutine to make requests to many +// different data loaders without blocking until the thunk is called. +func (l *Loader[K, T]) LoadAllThunk(keys []K) func() ([]T, []error) { + results := make([]func() (T, error), len(keys)) + for i, key := range keys { + results[i] = l.LoadThunk(key) + } + return func() ([]T, []error) { + users := make([]T, len(keys)) + errors := make([]error, len(keys)) + for i, thunk := range results { + users[i], errors[i] = thunk() + } + return users, errors + } +} + +// Prime the cache with the provided key and value. If the key already exists, no change is made +// and false is returned. +// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) +func (l *Loader[K, T]) Prime(key K, value T) bool { + l.mu.Lock() + var found bool + if _, found = l.cache[key]; !found { + l.unsafeSet(key, value) + } + l.mu.Unlock() + return !found +} + +// Clear the value at key from the cache, if it exists +func (l *Loader[K, T]) Clear(key K) { + l.mu.Lock() + delete(l.cache, key) + l.mu.Unlock() +} + +func (l *Loader[K, T]) unsafeSet(key K, value T) { + if l.cache == nil { + l.cache = map[K]T{} + } + l.cache[key] = value +} + +// keyIndex will return the location of the key in the batch, if its not found +// it will add the key to the batch +func (b *loaderBatch[K, T]) keyIndex(l *Loader[K, T], key K) int { + for i, existingKey := range b.keys { + if key == existingKey { + return i + } + } + + pos := len(b.keys) + b.keys = append(b.keys, key) + if pos == 0 { + go b.startTimer(l) + } + + if l.maxBatch != 0 && pos >= l.maxBatch-1 { + if !b.closing { + b.closing = true + l.batch = nil + go b.end(l) + } + } + + return pos +} + +func (b *loaderBatch[K, T]) startTimer(l *Loader[K, T]) { + time.Sleep(l.wait) + l.mu.Lock() + + // we must have hit a batch limit and are already finalizing this batch + if b.closing { + l.mu.Unlock() + return + } + + l.batch = nil + l.mu.Unlock() + + b.end(l) +} + +func (b *loaderBatch[K, T]) end(l *Loader[K, T]) { + b.data, b.error = l.fetch(b.keys) + close(b.done) } diff --git a/example/benchmark_test.go b/example/benchmark_test.go deleted file mode 100644 index 657d418..0000000 --- a/example/benchmark_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package example - -import ( - "fmt" - "math/rand" - "strconv" - "sync" - "testing" - "time" -) - -func BenchmarkLoader(b *testing.B) { - dl := &UserLoader{ - wait: 500 * time.Nanosecond, - maxBatch: 100, - fetch: func(keys []string) ([]*User, []error) { - users := make([]*User, len(keys)) - errors := make([]error, len(keys)) - - for i, key := range keys { - if rand.Int()%100 == 1 { - errors[i] = fmt.Errorf("user not found") - } else if rand.Int()%100 == 1 { - users[i] = nil - } else { - users[i] = &User{ID: key, Name: "user " + key} - } - } - return users, errors - }, - } - - b.Run("caches", func(b *testing.B) { - thunks := make([]func() (*User, error), b.N) - for i := 0; i < b.N; i++ { - thunks[i] = dl.LoadThunk(strconv.Itoa(rand.Int() % 300)) - } - - for i := 0; i < b.N; i++ { - thunks[i]() - } - }) - - b.Run("random spread", func(b *testing.B) { - thunks := make([]func() (*User, error), b.N) - for i := 0; i < b.N; i++ { - thunks[i] = dl.LoadThunk(strconv.Itoa(rand.Int())) - } - - for i := 0; i < b.N; i++ { - thunks[i]() - } - }) - - b.Run("concurently", func(b *testing.B) { - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - for j := 0; j < b.N; j++ { - dl.Load(strconv.Itoa(rand.Int())) - } - wg.Done() - }() - } - wg.Wait() - }) -} diff --git a/example/pkgname/user.go b/example/pkgname/user.go deleted file mode 100644 index f9a4bf5..0000000 --- a/example/pkgname/user.go +++ /dev/null @@ -1,3 +0,0 @@ -package differentpkg - -//go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User diff --git a/example/pkgname/userloader_gen.go b/example/pkgname/userloader_gen.go deleted file mode 100644 index 3495d73..0000000 --- a/example/pkgname/userloader_gen.go +++ /dev/null @@ -1,224 +0,0 @@ -// Code generated by github.com/vektah/dataloaden, DO NOT EDIT. - -package differentpkg - -import ( - "sync" - "time" - - "github.com/vektah/dataloaden/example" -) - -// UserLoaderConfig captures the config to create a new UserLoader -type UserLoaderConfig struct { - // Fetch is a method that provides the data for the loader - Fetch func(keys []string) ([]*example.User, []error) - - // Wait is how long wait before sending a batch - Wait time.Duration - - // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit - MaxBatch int -} - -// NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch -func NewUserLoader(config UserLoaderConfig) *UserLoader { - return &UserLoader{ - fetch: config.Fetch, - wait: config.Wait, - maxBatch: config.MaxBatch, - } -} - -// UserLoader batches and caches requests -type UserLoader struct { - // this method provides the data for the loader - fetch func(keys []string) ([]*example.User, []error) - - // how long to done before sending a batch - wait time.Duration - - // this will limit the maximum number of keys to send in one batch, 0 = no limit - maxBatch int - - // INTERNAL - - // lazily created cache - cache map[string]*example.User - - // the current batch. keys will continue to be collected until timeout is hit, - // then everything will be sent to the fetch method and out to the listeners - batch *userLoaderBatch - - // mutex to prevent races - mu sync.Mutex -} - -type userLoaderBatch struct { - keys []string - data []*example.User - error []error - closing bool - done chan struct{} -} - -// Load a User by key, batching and caching will be applied automatically -func (l *UserLoader) Load(key string) (*example.User, error) { - return l.LoadThunk(key)() -} - -// LoadThunk returns a function that when called will block waiting for a User. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserLoader) LoadThunk(key string) func() (*example.User, error) { - l.mu.Lock() - if it, ok := l.cache[key]; ok { - l.mu.Unlock() - return func() (*example.User, error) { - return it, nil - } - } - if l.batch == nil { - l.batch = &userLoaderBatch{done: make(chan struct{})} - } - batch := l.batch - pos := batch.keyIndex(l, key) - l.mu.Unlock() - - return func() (*example.User, error) { - <-batch.done - - var data *example.User - if pos < len(batch.data) { - data = batch.data[pos] - } - - var err error - // its convenient to be able to return a single error for everything - if len(batch.error) == 1 { - err = batch.error[0] - } else if batch.error != nil { - err = batch.error[pos] - } - - if err == nil { - l.mu.Lock() - l.unsafeSet(key, data) - l.mu.Unlock() - } - - return data, err - } -} - -// LoadAll fetches many keys at once. It will be broken into appropriate sized -// sub batches depending on how the loader is configured -func (l *UserLoader) LoadAll(keys []string) ([]*example.User, []error) { - results := make([]func() (*example.User, error), len(keys)) - - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - - users := make([]*example.User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors -} - -// LoadAllThunk returns a function that when called will block waiting for a Users. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*example.User, []error) { - results := make([]func() (*example.User, error), len(keys)) - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - return func() ([]*example.User, []error) { - users := make([]*example.User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors - } -} - -// Prime the cache with the provided key and value. If the key already exists, no change is made -// and false is returned. -// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) -func (l *UserLoader) Prime(key string, value *example.User) bool { - l.mu.Lock() - var found bool - if _, found = l.cache[key]; !found { - // make a copy when writing to the cache, its easy to pass a pointer in from a loop var - // and end up with the whole cache pointing to the same value. - cpy := *value - l.unsafeSet(key, &cpy) - } - l.mu.Unlock() - return !found -} - -// Clear the value at key from the cache, if it exists -func (l *UserLoader) Clear(key string) { - l.mu.Lock() - delete(l.cache, key) - l.mu.Unlock() -} - -func (l *UserLoader) unsafeSet(key string, value *example.User) { - if l.cache == nil { - l.cache = map[string]*example.User{} - } - l.cache[key] = value -} - -// keyIndex will return the location of the key in the batch, if its not found -// it will add the key to the batch -func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { - for i, existingKey := range b.keys { - if key == existingKey { - return i - } - } - - pos := len(b.keys) - b.keys = append(b.keys, key) - if pos == 0 { - go b.startTimer(l) - } - - if l.maxBatch != 0 && pos >= l.maxBatch-1 { - if !b.closing { - b.closing = true - l.batch = nil - go b.end(l) - } - } - - return pos -} - -func (b *userLoaderBatch) startTimer(l *UserLoader) { - time.Sleep(l.wait) - l.mu.Lock() - - // we must have hit a batch limit and are already finalizing this batch - if b.closing { - l.mu.Unlock() - return - } - - l.batch = nil - l.mu.Unlock() - - b.end(l) -} - -func (b *userLoaderBatch) end(l *UserLoader) { - b.data, b.error = l.fetch(b.keys) - close(b.done) -} diff --git a/example/slice/user.go b/example/slice/user.go index 767f2c1..c422522 100644 --- a/example/slice/user.go +++ b/example/slice/user.go @@ -1,19 +1,18 @@ -//go:generate go run github.com/vektah/dataloaden UserSliceLoader int []github.com/vektah/dataloaden/example.User - package slice import ( "strconv" "time" + "github.com/vektah/dataloaden" "github.com/vektah/dataloaden/example" ) -func NewLoader() *UserSliceLoader { - return &UserSliceLoader{ - wait: 2 * time.Millisecond, - maxBatch: 100, - fetch: func(keys []int) ([][]example.User, []error) { +func NewLoader() *dataloaden.Loader[int, []example.User] { + return dataloaden.NewLoader(dataloaden.LoaderConfig[int, []example.User]{ + Wait: 2 * time.Millisecond, + MaxBatch: 100, + Fetch: func(keys []int) ([][]example.User, []error) { users := make([][]example.User, len(keys)) errors := make([]error, len(keys)) @@ -22,5 +21,5 @@ func NewLoader() *UserSliceLoader { } return users, errors }, - } + }) } diff --git a/example/slice/usersliceloader_gen.go b/example/slice/usersliceloader_gen.go deleted file mode 100644 index c2d6e83..0000000 --- a/example/slice/usersliceloader_gen.go +++ /dev/null @@ -1,225 +0,0 @@ -// Code generated by github.com/vektah/dataloaden, DO NOT EDIT. - -package slice - -import ( - "sync" - "time" - - "github.com/vektah/dataloaden/example" -) - -// UserSliceLoaderConfig captures the config to create a new UserSliceLoader -type UserSliceLoaderConfig struct { - // Fetch is a method that provides the data for the loader - Fetch func(keys []int) ([][]example.User, []error) - - // Wait is how long wait before sending a batch - Wait time.Duration - - // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit - MaxBatch int -} - -// NewUserSliceLoader creates a new UserSliceLoader given a fetch, wait, and maxBatch -func NewUserSliceLoader(config UserSliceLoaderConfig) *UserSliceLoader { - return &UserSliceLoader{ - fetch: config.Fetch, - wait: config.Wait, - maxBatch: config.MaxBatch, - } -} - -// UserSliceLoader batches and caches requests -type UserSliceLoader struct { - // this method provides the data for the loader - fetch func(keys []int) ([][]example.User, []error) - - // how long to done before sending a batch - wait time.Duration - - // this will limit the maximum number of keys to send in one batch, 0 = no limit - maxBatch int - - // INTERNAL - - // lazily created cache - cache map[int][]example.User - - // the current batch. keys will continue to be collected until timeout is hit, - // then everything will be sent to the fetch method and out to the listeners - batch *userSliceLoaderBatch - - // mutex to prevent races - mu sync.Mutex -} - -type userSliceLoaderBatch struct { - keys []int - data [][]example.User - error []error - closing bool - done chan struct{} -} - -// Load a User by key, batching and caching will be applied automatically -func (l *UserSliceLoader) Load(key int) ([]example.User, error) { - return l.LoadThunk(key)() -} - -// LoadThunk returns a function that when called will block waiting for a User. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserSliceLoader) LoadThunk(key int) func() ([]example.User, error) { - l.mu.Lock() - if it, ok := l.cache[key]; ok { - l.mu.Unlock() - return func() ([]example.User, error) { - return it, nil - } - } - if l.batch == nil { - l.batch = &userSliceLoaderBatch{done: make(chan struct{})} - } - batch := l.batch - pos := batch.keyIndex(l, key) - l.mu.Unlock() - - return func() ([]example.User, error) { - <-batch.done - - var data []example.User - if pos < len(batch.data) { - data = batch.data[pos] - } - - var err error - // its convenient to be able to return a single error for everything - if len(batch.error) == 1 { - err = batch.error[0] - } else if batch.error != nil { - err = batch.error[pos] - } - - if err == nil { - l.mu.Lock() - l.unsafeSet(key, data) - l.mu.Unlock() - } - - return data, err - } -} - -// LoadAll fetches many keys at once. It will be broken into appropriate sized -// sub batches depending on how the loader is configured -func (l *UserSliceLoader) LoadAll(keys []int) ([][]example.User, []error) { - results := make([]func() ([]example.User, error), len(keys)) - - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - - users := make([][]example.User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors -} - -// LoadAllThunk returns a function that when called will block waiting for a Users. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserSliceLoader) LoadAllThunk(keys []int) func() ([][]example.User, []error) { - results := make([]func() ([]example.User, error), len(keys)) - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - return func() ([][]example.User, []error) { - users := make([][]example.User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors - } -} - -// Prime the cache with the provided key and value. If the key already exists, no change is made -// and false is returned. -// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) -func (l *UserSliceLoader) Prime(key int, value []example.User) bool { - l.mu.Lock() - var found bool - if _, found = l.cache[key]; !found { - // make a copy when writing to the cache, its easy to pass a pointer in from a loop var - // and end up with the whole cache pointing to the same value. - cpy := make([]example.User, len(value)) - copy(cpy, value) - l.unsafeSet(key, cpy) - } - l.mu.Unlock() - return !found -} - -// Clear the value at key from the cache, if it exists -func (l *UserSliceLoader) Clear(key int) { - l.mu.Lock() - delete(l.cache, key) - l.mu.Unlock() -} - -func (l *UserSliceLoader) unsafeSet(key int, value []example.User) { - if l.cache == nil { - l.cache = map[int][]example.User{} - } - l.cache[key] = value -} - -// keyIndex will return the location of the key in the batch, if its not found -// it will add the key to the batch -func (b *userSliceLoaderBatch) keyIndex(l *UserSliceLoader, key int) int { - for i, existingKey := range b.keys { - if key == existingKey { - return i - } - } - - pos := len(b.keys) - b.keys = append(b.keys, key) - if pos == 0 { - go b.startTimer(l) - } - - if l.maxBatch != 0 && pos >= l.maxBatch-1 { - if !b.closing { - b.closing = true - l.batch = nil - go b.end(l) - } - } - - return pos -} - -func (b *userSliceLoaderBatch) startTimer(l *UserSliceLoader) { - time.Sleep(l.wait) - l.mu.Lock() - - // we must have hit a batch limit and are already finalizing this batch - if b.closing { - l.mu.Unlock() - return - } - - l.batch = nil - l.mu.Unlock() - - b.end(l) -} - -func (b *userSliceLoaderBatch) end(l *UserSliceLoader) { - b.data, b.error = l.fetch(b.keys) - close(b.done) -} diff --git a/example/slice/usersliceloader_test.go b/example/slice/usersliceloader_test.go index 857b197..fc5eae3 100644 --- a/example/slice/usersliceloader_test.go +++ b/example/slice/usersliceloader_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vektah/dataloaden" "github.com/vektah/dataloaden/example" ) @@ -16,10 +16,10 @@ func TestUserLoader(t *testing.T) { var fetches [][]int var mu sync.Mutex - dl := &UserSliceLoader{ - wait: 10 * time.Millisecond, - maxBatch: 5, - fetch: func(keys []int) (users [][]example.User, errors []error) { + dl := dataloaden.NewLoader(dataloaden.LoaderConfig[int, []example.User]{ + Wait: 10 * time.Millisecond, + MaxBatch: 5, + Fetch: func(keys []int) (users [][]example.User, errors []error) { mu.Lock() fetches = append(fetches, keys) mu.Unlock() @@ -39,7 +39,7 @@ func TestUserLoader(t *testing.T) { } return users, errors }, - } + }) t.Run("fetch concurrent data", func(t *testing.T) { t.Run("load user successfully", func(t *testing.T) { @@ -86,8 +86,8 @@ func TestUserLoader(t *testing.T) { defer mu.Unlock() require.Len(t, fetches, 2) - assert.Len(t, fetches[0], 5) - assert.Len(t, fetches[1], 3) + require.Len(t, fetches[0], 5) + require.Len(t, fetches[1], 3) }) t.Run("fetch more", func(t *testing.T) { diff --git a/example/user.go b/example/user.go index 24d2863..f3a8114 100644 --- a/example/user.go +++ b/example/user.go @@ -1,9 +1,9 @@ -//go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User - package example import ( "time" + + "github.com/vektah/dataloaden" ) // User is some kind of database backed model @@ -14,11 +14,11 @@ type User struct { // NewLoader will collect user requests for 2 milliseconds and send them as a single batch to the fetch func // normally fetch would be a database call. -func NewLoader() *UserLoader { - return &UserLoader{ - wait: 2 * time.Millisecond, - maxBatch: 100, - fetch: func(keys []string) ([]*User, []error) { +func NewLoader() *dataloaden.Loader[string, *User] { + return dataloaden.NewLoader(dataloaden.LoaderConfig[string, *User]{ + Wait: 2 * time.Millisecond, + MaxBatch: 100, + Fetch: func(keys []string) ([]*User, []error) { users := make([]*User, len(keys)) errors := make([]error, len(keys)) @@ -27,5 +27,5 @@ func NewLoader() *UserLoader { } return users, errors }, - } + }) } diff --git a/example/user_test.go b/example/user_test.go index 342a4ad..5b9cf3f 100644 --- a/example/user_test.go +++ b/example/user_test.go @@ -7,18 +7,18 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vektah/dataloaden" ) func TestUserLoader(t *testing.T) { var fetches [][]string var mu sync.Mutex - dl := &UserLoader{ - wait: 10 * time.Millisecond, - maxBatch: 5, - fetch: func(keys []string) ([]*User, []error) { + dl := dataloaden.NewLoader(dataloaden.LoaderConfig[string, *User]{ + Wait: 10 * time.Millisecond, + MaxBatch: 5, + Fetch: func(keys []string) ([]*User, []error) { mu.Lock() fetches = append(fetches, keys) mu.Unlock() @@ -35,7 +35,7 @@ func TestUserLoader(t *testing.T) { } return users, errors }, - } + }) t.Run("fetch concurrent data", func(t *testing.T) { t.Run("load user successfully", func(t *testing.T) { @@ -81,8 +81,8 @@ func TestUserLoader(t *testing.T) { defer mu.Unlock() require.Len(t, fetches, 2) - assert.Len(t, fetches[0], 5) - assert.Len(t, fetches[1], 3) + require.Len(t, fetches[0], 5) + require.Len(t, fetches[1], 3) }) t.Run("fetch more", func(t *testing.T) { @@ -153,7 +153,9 @@ func TestUserLoader(t *testing.T) { {ID: "Omega", Name: "Omega"}, } for _, user := range users { - dl.Prime(user.ID, &user) + u := &user + cpy := *u + dl.Prime(user.ID, &cpy) } u, err := dl.Load("Alpha") diff --git a/example/userloader_gen.go b/example/userloader_gen.go deleted file mode 100644 index 470ba6a..0000000 --- a/example/userloader_gen.go +++ /dev/null @@ -1,222 +0,0 @@ -// Code generated by github.com/vektah/dataloaden, DO NOT EDIT. - -package example - -import ( - "sync" - "time" -) - -// UserLoaderConfig captures the config to create a new UserLoader -type UserLoaderConfig struct { - // Fetch is a method that provides the data for the loader - Fetch func(keys []string) ([]*User, []error) - - // Wait is how long wait before sending a batch - Wait time.Duration - - // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit - MaxBatch int -} - -// NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch -func NewUserLoader(config UserLoaderConfig) *UserLoader { - return &UserLoader{ - fetch: config.Fetch, - wait: config.Wait, - maxBatch: config.MaxBatch, - } -} - -// UserLoader batches and caches requests -type UserLoader struct { - // this method provides the data for the loader - fetch func(keys []string) ([]*User, []error) - - // how long to done before sending a batch - wait time.Duration - - // this will limit the maximum number of keys to send in one batch, 0 = no limit - maxBatch int - - // INTERNAL - - // lazily created cache - cache map[string]*User - - // the current batch. keys will continue to be collected until timeout is hit, - // then everything will be sent to the fetch method and out to the listeners - batch *userLoaderBatch - - // mutex to prevent races - mu sync.Mutex -} - -type userLoaderBatch struct { - keys []string - data []*User - error []error - closing bool - done chan struct{} -} - -// Load a User by key, batching and caching will be applied automatically -func (l *UserLoader) Load(key string) (*User, error) { - return l.LoadThunk(key)() -} - -// LoadThunk returns a function that when called will block waiting for a User. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserLoader) LoadThunk(key string) func() (*User, error) { - l.mu.Lock() - if it, ok := l.cache[key]; ok { - l.mu.Unlock() - return func() (*User, error) { - return it, nil - } - } - if l.batch == nil { - l.batch = &userLoaderBatch{done: make(chan struct{})} - } - batch := l.batch - pos := batch.keyIndex(l, key) - l.mu.Unlock() - - return func() (*User, error) { - <-batch.done - - var data *User - if pos < len(batch.data) { - data = batch.data[pos] - } - - var err error - // its convenient to be able to return a single error for everything - if len(batch.error) == 1 { - err = batch.error[0] - } else if batch.error != nil { - err = batch.error[pos] - } - - if err == nil { - l.mu.Lock() - l.unsafeSet(key, data) - l.mu.Unlock() - } - - return data, err - } -} - -// LoadAll fetches many keys at once. It will be broken into appropriate sized -// sub batches depending on how the loader is configured -func (l *UserLoader) LoadAll(keys []string) ([]*User, []error) { - results := make([]func() (*User, error), len(keys)) - - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - - users := make([]*User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors -} - -// LoadAllThunk returns a function that when called will block waiting for a Users. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*User, []error) { - results := make([]func() (*User, error), len(keys)) - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - return func() ([]*User, []error) { - users := make([]*User, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - users[i], errors[i] = thunk() - } - return users, errors - } -} - -// Prime the cache with the provided key and value. If the key already exists, no change is made -// and false is returned. -// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) -func (l *UserLoader) Prime(key string, value *User) bool { - l.mu.Lock() - var found bool - if _, found = l.cache[key]; !found { - // make a copy when writing to the cache, its easy to pass a pointer in from a loop var - // and end up with the whole cache pointing to the same value. - cpy := *value - l.unsafeSet(key, &cpy) - } - l.mu.Unlock() - return !found -} - -// Clear the value at key from the cache, if it exists -func (l *UserLoader) Clear(key string) { - l.mu.Lock() - delete(l.cache, key) - l.mu.Unlock() -} - -func (l *UserLoader) unsafeSet(key string, value *User) { - if l.cache == nil { - l.cache = map[string]*User{} - } - l.cache[key] = value -} - -// keyIndex will return the location of the key in the batch, if its not found -// it will add the key to the batch -func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { - for i, existingKey := range b.keys { - if key == existingKey { - return i - } - } - - pos := len(b.keys) - b.keys = append(b.keys, key) - if pos == 0 { - go b.startTimer(l) - } - - if l.maxBatch != 0 && pos >= l.maxBatch-1 { - if !b.closing { - b.closing = true - l.batch = nil - go b.end(l) - } - } - - return pos -} - -func (b *userLoaderBatch) startTimer(l *UserLoader) { - time.Sleep(l.wait) - l.mu.Lock() - - // we must have hit a batch limit and are already finalizing this batch - if b.closing { - l.mu.Unlock() - return - } - - l.batch = nil - l.mu.Unlock() - - b.end(l) -} - -func (b *userLoaderBatch) end(l *UserLoader) { - b.data, b.error = l.fetch(b.keys) - close(b.done) -} diff --git a/go.mod b/go.mod index ba56ca0..d8c694a 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,10 @@ module github.com/vektah/dataloaden +go 1.18 + +require github.com/stretchr/testify v1.2.1 + require ( github.com/davecgh/go-spew v1.1.0 // indirect - github.com/pkg/errors v0.8.1 github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.2.1 - golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd ) diff --git a/go.sum b/go.sum index a350afb..314bd6a 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,6 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.2.1 h1:52QO5WkIUcHGIR7EnGagH88x1bUzqGXTC5/1bDTUQ7U= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd h1:oMEQDWVXVNpceQoVd1JN3CQ7LYJJzs5qWqZIUcxXHHw= -golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/pkg/generator/generator.go b/pkg/generator/generator.go deleted file mode 100644 index ff618e7..0000000 --- a/pkg/generator/generator.go +++ /dev/null @@ -1,163 +0,0 @@ -package generator - -import ( - "bytes" - "fmt" - "io/ioutil" - "path/filepath" - "regexp" - "strings" - "unicode" - - "github.com/pkg/errors" - "golang.org/x/tools/go/packages" - "golang.org/x/tools/imports" -) - -type templateData struct { - Package string - Name string - KeyType *goType - ValType *goType -} - -type goType struct { - Modifiers string - ImportPath string - ImportName string - Name string -} - -func (t *goType) String() string { - if t.ImportName != "" { - return t.Modifiers + t.ImportName + "." + t.Name - } - - return t.Modifiers + t.Name -} - -func (t *goType) IsPtr() bool { - return strings.HasPrefix(t.Modifiers, "*") -} - -func (t *goType) IsSlice() bool { - return strings.HasPrefix(t.Modifiers, "[]") -} - -var partsRe = regexp.MustCompile(`^([\[\]\*]*)(.*?)(\.\w*)?$`) - -func parseType(str string) (*goType, error) { - parts := partsRe.FindStringSubmatch(str) - if len(parts) != 4 { - return nil, fmt.Errorf("type must be in the form []*github.com/import/path.Name") - } - - t := &goType{ - Modifiers: parts[1], - ImportPath: parts[2], - Name: strings.TrimPrefix(parts[3], "."), - } - - if t.Name == "" { - t.Name = t.ImportPath - t.ImportPath = "" - } - - if t.ImportPath != "" { - p, err := packages.Load(&packages.Config{Mode: packages.NeedName}, t.ImportPath) - if err != nil { - return nil, err - } - if len(p) != 1 { - return nil, fmt.Errorf("not found") - } - - t.ImportName = p[0].Name - } - - return t, nil -} - -func Generate(name string, keyType string, valueType string, wd string) error { - data, err := getData(name, keyType, valueType, wd) - if err != nil { - return err - } - - filename := strings.ToLower(data.Name) + "_gen.go" - - if err := writeTemplate(filepath.Join(wd, filename), data); err != nil { - return err - } - - return nil -} - -func getData(name string, keyType string, valueType string, wd string) (templateData, error) { - var data templateData - - genPkg := getPackage(wd) - if genPkg == nil { - return templateData{}, fmt.Errorf("unable to find package info for " + wd) - } - - var err error - data.Name = name - data.Package = genPkg.Name - data.KeyType, err = parseType(keyType) - if err != nil { - return templateData{}, fmt.Errorf("key type: %s", err.Error()) - } - data.ValType, err = parseType(valueType) - if err != nil { - return templateData{}, fmt.Errorf("key type: %s", err.Error()) - } - - // if we are inside the same package as the type we don't need an import and can refer directly to the type - if genPkg.PkgPath == data.ValType.ImportPath { - data.ValType.ImportName = "" - data.ValType.ImportPath = "" - } - if genPkg.PkgPath == data.KeyType.ImportPath { - data.KeyType.ImportName = "" - data.KeyType.ImportPath = "" - } - - return data, nil -} - -func getPackage(dir string) *packages.Package { - p, _ := packages.Load(&packages.Config{ - Dir: dir, - }, ".") - - if len(p) != 1 { - return nil - } - - return p[0] -} - -func writeTemplate(filepath string, data templateData) error { - var buf bytes.Buffer - if err := tpl.Execute(&buf, data); err != nil { - return errors.Wrap(err, "generating code") - } - - src, err := imports.Process(filepath, buf.Bytes(), nil) - if err != nil { - return errors.Wrap(err, "unable to gofmt") - } - - if err := ioutil.WriteFile(filepath, src, 0644); err != nil { - return errors.Wrap(err, "writing output") - } - - return nil -} - -func lcFirst(s string) string { - r := []rune(s) - r[0] = unicode.ToLower(r[0]) - return string(r) -} diff --git a/pkg/generator/generator_test.go b/pkg/generator/generator_test.go deleted file mode 100644 index ee8d2fe..0000000 --- a/pkg/generator/generator_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package generator - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseType(t *testing.T) { - require.Equal(t, &goType{Name: "string"}, parse("string")) - require.Equal(t, &goType{Name: "Time", ImportPath: "time", ImportName: "time"}, parse("time.Time")) - require.Equal(t, &goType{ - Name: "Foo", - ImportPath: "github.com/vektah/dataloaden/pkg/generator/testdata/mismatch", - ImportName: "mismatched", - }, parse("github.com/vektah/dataloaden/pkg/generator/testdata/mismatch.Foo")) -} - -func parse(s string) *goType { - t, err := parseType(s) - if err != nil { - panic(err) - } - - return t -} diff --git a/pkg/generator/template.go b/pkg/generator/template.go deleted file mode 100644 index 48f5ba2..0000000 --- a/pkg/generator/template.go +++ /dev/null @@ -1,245 +0,0 @@ -package generator - -import "text/template" - -var tpl = template.Must(template.New("generated"). - Funcs(template.FuncMap{ - "lcFirst": lcFirst, - }). - Parse(` -// Code generated by github.com/vektah/dataloaden, DO NOT EDIT. - -package {{.Package}} - -import ( - "sync" - "time" - - {{if .KeyType.ImportPath}}"{{.KeyType.ImportPath}}"{{end}} - {{if .ValType.ImportPath}}"{{.ValType.ImportPath}}"{{end}} -) - -// {{.Name}}Config captures the config to create a new {{.Name}} -type {{.Name}}Config struct { - // Fetch is a method that provides the data for the loader - Fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) - - // Wait is how long wait before sending a batch - Wait time.Duration - - // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit - MaxBatch int -} - -// New{{.Name}} creates a new {{.Name}} given a fetch, wait, and maxBatch -func New{{.Name}}(config {{.Name}}Config) *{{.Name}} { - return &{{.Name}}{ - fetch: config.Fetch, - wait: config.Wait, - maxBatch: config.MaxBatch, - } -} - -// {{.Name}} batches and caches requests -type {{.Name}} struct { - // this method provides the data for the loader - fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) - - // how long to done before sending a batch - wait time.Duration - - // this will limit the maximum number of keys to send in one batch, 0 = no limit - maxBatch int - - // INTERNAL - - // lazily created cache - cache map[{{.KeyType.String}}]{{.ValType.String}} - - // the current batch. keys will continue to be collected until timeout is hit, - // then everything will be sent to the fetch method and out to the listeners - batch *{{.Name|lcFirst}}Batch - - // mutex to prevent races - mu sync.Mutex -} - -type {{.Name|lcFirst}}Batch struct { - keys []{{.KeyType}} - data []{{.ValType.String}} - error []error - closing bool - done chan struct{} -} - -// Load a {{.ValType.Name}} by key, batching and caching will be applied automatically -func (l *{{.Name}}) Load(key {{.KeyType.String}}) ({{.ValType.String}}, error) { - return l.LoadThunk(key)() -} - -// LoadThunk returns a function that when called will block waiting for a {{.ValType.Name}}. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *{{.Name}}) LoadThunk(key {{.KeyType.String}}) func() ({{.ValType.String}}, error) { - l.mu.Lock() - if it, ok := l.cache[key]; ok { - l.mu.Unlock() - return func() ({{.ValType.String}}, error) { - return it, nil - } - } - if l.batch == nil { - l.batch = &{{.Name|lcFirst}}Batch{done: make(chan struct{})} - } - batch := l.batch - pos := batch.keyIndex(l, key) - l.mu.Unlock() - - return func() ({{.ValType.String}}, error) { - <-batch.done - - var data {{.ValType.String}} - if pos < len(batch.data) { - data = batch.data[pos] - } - - var err error - // its convenient to be able to return a single error for everything - if len(batch.error) == 1 { - err = batch.error[0] - } else if batch.error != nil { - err = batch.error[pos] - } - - if err == nil { - l.mu.Lock() - l.unsafeSet(key, data) - l.mu.Unlock() - } - - return data, err - } -} - -// LoadAll fetches many keys at once. It will be broken into appropriate sized -// sub batches depending on how the loader is configured -func (l *{{.Name}}) LoadAll(keys []{{.KeyType}}) ([]{{.ValType.String}}, []error) { - results := make([]func() ({{.ValType.String}}, error), len(keys)) - - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - - {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() - } - return {{.ValType.Name|lcFirst}}s, errors -} - -// LoadAllThunk returns a function that when called will block waiting for a {{.ValType.Name}}s. -// This method should be used if you want one goroutine to make requests to many -// different data loaders without blocking until the thunk is called. -func (l *{{.Name}}) LoadAllThunk(keys []{{.KeyType}}) (func() ([]{{.ValType.String}}, []error)) { - results := make([]func() ({{.ValType.String}}, error), len(keys)) - for i, key := range keys { - results[i] = l.LoadThunk(key) - } - return func() ([]{{.ValType.String}}, []error) { - {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) - errors := make([]error, len(keys)) - for i, thunk := range results { - {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() - } - return {{.ValType.Name|lcFirst}}s, errors - } -} - -// Prime the cache with the provided key and value. If the key already exists, no change is made -// and false is returned. -// (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) -func (l *{{.Name}}) Prime(key {{.KeyType}}, value {{.ValType.String}}) bool { - l.mu.Lock() - var found bool - if _, found = l.cache[key]; !found { - {{- if .ValType.IsPtr }} - // make a copy when writing to the cache, its easy to pass a pointer in from a loop var - // and end up with the whole cache pointing to the same value. - cpy := *value - l.unsafeSet(key, &cpy) - {{- else if .ValType.IsSlice }} - // make a copy when writing to the cache, its easy to pass a pointer in from a loop var - // and end up with the whole cache pointing to the same value. - cpy := make({{.ValType.String}}, len(value)) - copy(cpy, value) - l.unsafeSet(key, cpy) - {{- else }} - l.unsafeSet(key, value) - {{- end }} - } - l.mu.Unlock() - return !found -} - -// Clear the value at key from the cache, if it exists -func (l *{{.Name}}) Clear(key {{.KeyType}}) { - l.mu.Lock() - delete(l.cache, key) - l.mu.Unlock() -} - -func (l *{{.Name}}) unsafeSet(key {{.KeyType}}, value {{.ValType.String}}) { - if l.cache == nil { - l.cache = map[{{.KeyType}}]{{.ValType.String}}{} - } - l.cache[key] = value -} - -// keyIndex will return the location of the key in the batch, if its not found -// it will add the key to the batch -func (b *{{.Name|lcFirst}}Batch) keyIndex(l *{{.Name}}, key {{.KeyType}}) int { - for i, existingKey := range b.keys { - if key == existingKey { - return i - } - } - - pos := len(b.keys) - b.keys = append(b.keys, key) - if pos == 0 { - go b.startTimer(l) - } - - if l.maxBatch != 0 && pos >= l.maxBatch-1 { - if !b.closing { - b.closing = true - l.batch = nil - go b.end(l) - } - } - - return pos -} - -func (b *{{.Name|lcFirst}}Batch) startTimer(l *{{.Name}}) { - time.Sleep(l.wait) - l.mu.Lock() - - // we must have hit a batch limit and are already finalizing this batch - if b.closing { - l.mu.Unlock() - return - } - - l.batch = nil - l.mu.Unlock() - - b.end(l) -} - -func (b *{{.Name|lcFirst}}Batch) end(l *{{.Name}}) { - b.data, b.error = l.fetch(b.keys) - close(b.done) -} -`)) diff --git a/pkg/generator/testdata/mismatch/mismatch.go b/pkg/generator/testdata/mismatch/mismatch.go deleted file mode 100644 index 79c8ba2..0000000 --- a/pkg/generator/testdata/mismatch/mismatch.go +++ /dev/null @@ -1,5 +0,0 @@ -package mismatched - -type Foo struct { - Name string -}