Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
refactor: add contexts support
Browse files Browse the repository at this point in the history
  • Loading branch information
thde committed Oct 11, 2023
1 parent 0ab5b77 commit c9bb807
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
31 changes: 16 additions & 15 deletions internal/zenduty/zenduty.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zenduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -24,7 +25,7 @@ const (
)

type Client struct {
credentials func() (username string, password string)
credentials func(context.Context) (username string, password string)
http *http.Client
baseURL *url.URL
logger *slog.Logger
Expand Down Expand Up @@ -72,7 +73,7 @@ func NewLogger(options LoggerOptions) *slog.Logger {

// NewClient returns a new zenduty client which can be modified by passing
// options
func NewClient(credentials func() (string, string), opts ...ClientOption) *Client {
func NewClient(credentials func(context.Context) (string, string), opts ...ClientOption) *Client {
c := &Client{
credentials: credentials,
baseURL: defaultBaseURL(),
Expand Down Expand Up @@ -102,7 +103,7 @@ func Logger(logger *slog.Logger) ClientOption {
}

// Login executes a login with the given username and password
func (c *Client) Login() error {
func (c *Client) Login(ctx context.Context) error {
if c.http.Jar == nil {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
Expand All @@ -120,7 +121,7 @@ func (c *Client) Login() error {
return fmt.Errorf("error getting login page")
}

username, password := c.credentials()
username, password := c.credentials(ctx)
body := new(bytes.Buffer)
if err := json.NewEncoder(body).Encode(loginRequest{Email: username, Password: password}); err != nil {
return fmt.Errorf("can not encode login body: %w", err)
Expand Down Expand Up @@ -150,7 +151,7 @@ func (c *Client) doLoggedIn(req *http.Request, obj interface{}) error {
}

if !c.isLoggedIn() {
err := c.Login()
err := c.Login(req.Context())
if err != nil {
return err
}
Expand Down Expand Up @@ -299,9 +300,9 @@ func newScheduleFrom(data io.Reader) (*Schedule, error) {
return &Schedule{Calendar: cal}, nil
}

func (c *Client) listTeams() (teamList, error) {
func (c *Client) listTeams(ctx context.Context) (teamList, error) {
url := fmt.Sprintf("%s/api/account/teams", c.baseURL)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("can't create request to list teams: %w", err)
}
Expand All @@ -312,9 +313,9 @@ func (c *Client) listTeams() (teamList, error) {
return teamsResp, nil
}

func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) {
func (c *Client) listSchedules(ctx context.Context, teamID string) ([]apiSchedule, error) {
url := fmt.Sprintf("%s/api/account/teams/%s/schedules", c.baseURL, teamID)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request for listing team schedules: %w", err)
}
Expand All @@ -328,20 +329,20 @@ func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) {

// CombinedSchedule returns the full combined schedule of all team schedules where the given
// user email is part of
func (c *Client) CombinedSchedule(email string) (*Schedule, error) {
teams, err := c.listTeams()
func (c *Client) CombinedSchedule(ctx context.Context, email string) (*Schedule, error) {
teams, err := c.listTeams(ctx)
if err != nil {
return nil, err
}
teamsOfUser := teams.teamsForUser(email)
combined := ics.NewCalendarFor("zenduty-oncall")
for _, team := range teamsOfUser {
schedules, err := c.listSchedules(team.ID)
schedules, err := c.listSchedules(ctx, team.ID)
if err != nil {
return nil, err
}
for _, schedule := range schedules {
calendar, err := c.GetSchedule(team.ID, schedule.ID, amountMonths)
calendar, err := c.GetSchedule(ctx, team.ID, schedule.ID, amountMonths)
if err != nil {
return nil, err
}
Expand All @@ -361,9 +362,9 @@ func (c *Client) CombinedSchedule(email string) (*Schedule, error) {
return &Schedule{Calendar: combined}, nil
}

func (c *Client) GetSchedule(teamID, scheduleID string, months int) (*Schedule, error) {
func (c *Client) GetSchedule(ctx context.Context, teamID, scheduleID string, months int) (*Schedule, error) {
url := fmt.Sprintf("%s/api/account/teams/%s/schedules/%s/get_schedule_ics/?months=%d&is_team_or_user=1", c.baseURL, teamID, scheduleID, months)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request schedule %q of team %q: %w", scheduleID, teamID, err)
}
Expand Down
7 changes: 4 additions & 3 deletions internal/zenduty/zenduty_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zenduty

import (
"context"
"os"
"testing"
"time"
Expand All @@ -25,12 +26,12 @@ func TestICS(t *testing.T) {
// options.Level = slog.LevelDebug
logger := NewLogger(options)
z := NewClient(
func() (string, string) { return username, password },
func(context.Context) (string, string) { return username, password },
Logger(logger),
)
is.NoErr(z.Login())
is.NoErr(z.Login(context.Background()))

calendar, err := z.CombinedSchedule(username)
calendar, err := z.CombinedSchedule(context.Background(), username)
is.NoErr(err)
calendar = calendar.OnlyAttendees(username)
is.True(len(calendar.Events()) > 0)
Expand Down
11 changes: 6 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -33,10 +34,10 @@ func run(out io.Writer) error {
loggerOpts.Out = out
logger := zenduty.NewLogger(loggerOpts)
z := zenduty.NewClient(
func() (string, string) { return username, password },
func(context.Context) (string, string) { return username, password },
zenduty.Logger(logger),
)
if err := z.Login(); err != nil {
if err := z.Login(context.Background()); err != nil {
return err
}

Expand Down Expand Up @@ -75,9 +76,9 @@ func run(out io.Writer) error {
return server.ListenAndServe()
}

func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle {
func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(ctx context.Context, teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
schedule, err := getSchedule(ps.ByName(teamKey), ps.ByName(scheduleKey), 12)
schedule, err := getSchedule(r.Context(), ps.ByName(teamKey), ps.ByName(scheduleKey), 12)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
Expand All @@ -89,7 +90,7 @@ func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(t

func myScheduleHandler(c *zenduty.Client, forUser func(httprouter.Params) string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
schedule, err := c.CombinedSchedule(forUser(ps))
schedule, err := c.CombinedSchedule(r.Context(), forUser(ps))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
Expand Down

0 comments on commit c9bb807

Please sign in to comment.