From 6636d2e212473022a126c4e837edc6e488cefbfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Santamaria?= Date: Fri, 24 May 2024 23:35:52 +0200 Subject: [PATCH] Add SetGlobalRequestLimit and SetChatRequestLimit --- api.go | 2 +- network.go | 44 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/api.go b/api.go index 02fe990..de7e501 100644 --- a/api.go +++ b/api.go @@ -29,7 +29,7 @@ import ( type API struct { token string base string - client client + client *client } // NewAPI returns a new API object. diff --git a/network.go b/network.go index 1c3aa0a..3ee854c 100644 --- a/network.go +++ b/network.go @@ -29,6 +29,7 @@ import ( "net/url" "path/filepath" "strings" + "sync" "time" "golang.org/x/time/rate" @@ -36,30 +37,55 @@ import ( type client struct { *http.Client - cl map[string]*rate.Limiter // chat based limiter - gl *rate.Limiter // global limiter + *sync.RWMutex + cl map[string]*rate.Limiter // chat based limiter + gl *rate.Limiter // global limiter + climiter func() *rate.Limiter } var lclient = newClient() -func newClient() client { - return client{ - Client: &http.Client{}, - cl: make(map[string]*rate.Limiter), - gl: rate.NewLimiter(rate.Every(time.Second/30), 10), +// SetGlobalRequestLimit sets the global frequency of requests to the Telegram API. +func SetGlobalRequestLimit(d time.Duration) { + lclient.Lock() + lclient.gl = rate.NewLimiter(rate.Every(d), 10) + lclient.Unlock() +} + +// SetChatRequestLimit sets the per-chat frequency of requests to the Telegram API. +func SetChatRequestLimit(d time.Duration) { + lclient.Lock() + lclient.cl = make(map[string]*rate.Limiter) + lclient.climiter = func() *rate.Limiter { + return rate.NewLimiter(rate.Every(d), 1) + } + lclient.Unlock() +} + +func newClient() *client { + return &client{ + Client: new(http.Client), + RWMutex: new(sync.RWMutex), + cl: make(map[string]*rate.Limiter), + gl: rate.NewLimiter(rate.Every(time.Second/30), 10), + climiter: func() *rate.Limiter { + return rate.NewLimiter(rate.Every(time.Minute/20), 1) + }, } } func (c client) wait(chatID string) error { - var ctx = context.Background() + c.RLock() + defer c.RUnlock() + ctx := context.Background() // If the chatID is empty, it's a general API call like GetUpdates, GetMe // and similar, so skip the per-chat request limit wait. if chatID != "" { // If no limiter exists for a chat, create one. l, ok := c.cl[chatID] if !ok { - l = rate.NewLimiter(rate.Every(time.Minute/20), 1) + l = c.climiter() c.cl[chatID] = l }