From c9bb807826986c11f8909a06b9daba63e7c9d421 Mon Sep 17 00:00:00 2001 From: thde Date: Wed, 11 Oct 2023 10:34:17 +0200 Subject: [PATCH] refactor: add contexts support --- internal/zenduty/zenduty.go | 31 ++++++++++++++++--------------- internal/zenduty/zenduty_test.go | 7 ++++--- main.go | 11 ++++++----- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/internal/zenduty/zenduty.go b/internal/zenduty/zenduty.go index 3c4fdee..33e5bca 100644 --- a/internal/zenduty/zenduty.go +++ b/internal/zenduty/zenduty.go @@ -2,6 +2,7 @@ package zenduty import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -24,7 +25,7 @@ const ( ) type Client struct { - credentials func() (username string, password string) + credentials func(context.Context) (username string, password string) http *http.Client baseURL *url.URL logger *slog.Logger @@ -72,7 +73,7 @@ func NewLogger(options LoggerOptions) *slog.Logger { // NewClient returns a new zenduty client which can be modified by passing // options -func NewClient(credentials func() (string, string), opts ...ClientOption) *Client { +func NewClient(credentials func(context.Context) (string, string), opts ...ClientOption) *Client { c := &Client{ credentials: credentials, baseURL: defaultBaseURL(), @@ -102,7 +103,7 @@ func Logger(logger *slog.Logger) ClientOption { } // Login executes a login with the given username and password -func (c *Client) Login() error { +func (c *Client) Login(ctx context.Context) error { if c.http.Jar == nil { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { @@ -120,7 +121,7 @@ func (c *Client) Login() error { return fmt.Errorf("error getting login page") } - username, password := c.credentials() + username, password := c.credentials(ctx) body := new(bytes.Buffer) if err := json.NewEncoder(body).Encode(loginRequest{Email: username, Password: password}); err != nil { return fmt.Errorf("can not encode login body: %w", err) @@ -150,7 +151,7 @@ func (c *Client) doLoggedIn(req *http.Request, obj interface{}) error { } if !c.isLoggedIn() { - err := c.Login() + err := c.Login(req.Context()) if err != nil { return err } @@ -299,9 +300,9 @@ func newScheduleFrom(data io.Reader) (*Schedule, error) { return &Schedule{Calendar: cal}, nil } -func (c *Client) listTeams() (teamList, error) { +func (c *Client) listTeams(ctx context.Context) (teamList, error) { url := fmt.Sprintf("%s/api/account/teams", c.baseURL) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("can't create request to list teams: %w", err) } @@ -312,9 +313,9 @@ func (c *Client) listTeams() (teamList, error) { return teamsResp, nil } -func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) { +func (c *Client) listSchedules(ctx context.Context, teamID string) ([]apiSchedule, error) { url := fmt.Sprintf("%s/api/account/teams/%s/schedules", c.baseURL, teamID) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("error creating request for listing team schedules: %w", err) } @@ -328,20 +329,20 @@ func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) { // CombinedSchedule returns the full combined schedule of all team schedules where the given // user email is part of -func (c *Client) CombinedSchedule(email string) (*Schedule, error) { - teams, err := c.listTeams() +func (c *Client) CombinedSchedule(ctx context.Context, email string) (*Schedule, error) { + teams, err := c.listTeams(ctx) if err != nil { return nil, err } teamsOfUser := teams.teamsForUser(email) combined := ics.NewCalendarFor("zenduty-oncall") for _, team := range teamsOfUser { - schedules, err := c.listSchedules(team.ID) + schedules, err := c.listSchedules(ctx, team.ID) if err != nil { return nil, err } for _, schedule := range schedules { - calendar, err := c.GetSchedule(team.ID, schedule.ID, amountMonths) + calendar, err := c.GetSchedule(ctx, team.ID, schedule.ID, amountMonths) if err != nil { return nil, err } @@ -361,9 +362,9 @@ func (c *Client) CombinedSchedule(email string) (*Schedule, error) { return &Schedule{Calendar: combined}, nil } -func (c *Client) GetSchedule(teamID, scheduleID string, months int) (*Schedule, error) { +func (c *Client) GetSchedule(ctx context.Context, teamID, scheduleID string, months int) (*Schedule, error) { url := fmt.Sprintf("%s/api/account/teams/%s/schedules/%s/get_schedule_ics/?months=%d&is_team_or_user=1", c.baseURL, teamID, scheduleID, months) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("error creating request schedule %q of team %q: %w", scheduleID, teamID, err) } diff --git a/internal/zenduty/zenduty_test.go b/internal/zenduty/zenduty_test.go index 963d6e8..69d1edd 100644 --- a/internal/zenduty/zenduty_test.go +++ b/internal/zenduty/zenduty_test.go @@ -1,6 +1,7 @@ package zenduty import ( + "context" "os" "testing" "time" @@ -25,12 +26,12 @@ func TestICS(t *testing.T) { // options.Level = slog.LevelDebug logger := NewLogger(options) z := NewClient( - func() (string, string) { return username, password }, + func(context.Context) (string, string) { return username, password }, Logger(logger), ) - is.NoErr(z.Login()) + is.NoErr(z.Login(context.Background())) - calendar, err := z.CombinedSchedule(username) + calendar, err := z.CombinedSchedule(context.Background(), username) is.NoErr(err) calendar = calendar.OnlyAttendees(username) is.True(len(calendar.Events()) > 0) diff --git a/main.go b/main.go index f406e4e..607b22b 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "net/http" @@ -33,10 +34,10 @@ func run(out io.Writer) error { loggerOpts.Out = out logger := zenduty.NewLogger(loggerOpts) z := zenduty.NewClient( - func() (string, string) { return username, password }, + func(context.Context) (string, string) { return username, password }, zenduty.Logger(logger), ) - if err := z.Login(); err != nil { + if err := z.Login(context.Background()); err != nil { return err } @@ -75,9 +76,9 @@ func run(out io.Writer) error { return server.ListenAndServe() } -func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle { +func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(ctx context.Context, teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - schedule, err := getSchedule(ps.ByName(teamKey), ps.ByName(scheduleKey), 12) + schedule, err := getSchedule(r.Context(), ps.ByName(teamKey), ps.ByName(scheduleKey), 12) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) @@ -89,7 +90,7 @@ func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(t func myScheduleHandler(c *zenduty.Client, forUser func(httprouter.Params) string) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - schedule, err := c.CombinedSchedule(forUser(ps)) + schedule, err := c.CombinedSchedule(r.Context(), forUser(ps)) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error()))