diff --git a/.gitignore b/.gitignore index 849ddff..51eb69f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ dist/ +.env diff --git a/go.mod b/go.mod index bd73686..3588962 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,9 @@ go 1.20 require ( github.com/arran4/golang-ical v0.1.0 github.com/julienschmidt/httprouter v1.3.0 - github.com/stretchr/testify v1.8.4 + github.com/matryer/is v1.4.1 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/net v0.16.0 ) -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +require github.com/stretchr/testify v1.8.4 // indirect diff --git a/go.sum b/go.sum index 324568f..dda7c76 100644 --- a/go.sum +++ b/go.sum @@ -8,9 +8,9 @@ github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4d github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ= +github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -23,9 +23,7 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqR golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/zenduty/zenduty.go b/internal/zenduty/zenduty.go index 6eb03ea..33e5bca 100644 --- a/internal/zenduty/zenduty.go +++ b/internal/zenduty/zenduty.go @@ -2,6 +2,7 @@ package zenduty import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -24,9 +25,10 @@ const ( ) type Client struct { - http *http.Client - baseURL *url.URL - logger *slog.Logger + credentials func(context.Context) (username string, password string) + http *http.Client + baseURL *url.URL + logger *slog.Logger } type loginRequest struct { @@ -71,11 +73,12 @@ func NewLogger(options LoggerOptions) *slog.Logger { // NewClient returns a new zenduty client which can be modified by passing // options -func NewClient(opts ...ClientOption) *Client { +func NewClient(credentials func(context.Context) (string, string), opts ...ClientOption) *Client { c := &Client{ - baseURL: defaultBaseURL(), - http: defaultHTTPClient(), - logger: NewLogger(LoggerOptions{}), + credentials: credentials, + baseURL: defaultBaseURL(), + http: defaultHTTPClient(), + logger: NewLogger(LoggerOptions{}), } for _, opt := range opts { opt(c) @@ -85,14 +88,14 @@ func NewClient(opts ...ClientOption) *Client { type ClientOption func(c *Client) -// BaseURL sets the base URL of the zenduty client +// BaseURL sets the base URL of the Zenduty client func BaseURL(u *url.URL) ClientOption { return func(c *Client) { c.baseURL = u } } -// Logger sets the logger of the zenduty client +// Logger sets the logger of the Zenduty client func Logger(logger *slog.Logger) ClientOption { return func(c *Client) { c.logger = logger @@ -100,7 +103,7 @@ func Logger(logger *slog.Logger) ClientOption { } // Login executes a login with the given username and password -func (c *Client) Login(username, password string) error { +func (c *Client) Login(ctx context.Context) error { if c.http.Jar == nil { jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) if err != nil { @@ -118,6 +121,7 @@ func (c *Client) Login(username, password string) error { return fmt.Errorf("error getting login page") } + username, password := c.credentials(ctx) body := new(bytes.Buffer) if err := json.NewEncoder(body).Encode(loginRequest{Email: username, Password: password}); err != nil { return fmt.Errorf("can not encode login body: %w", err) @@ -136,10 +140,45 @@ func (c *Client) Login(username, password string) error { return nil } +func (c *Client) doLoggedIn(req *http.Request, obj interface{}) error { + if c.http.Jar == nil { + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + return fmt.Errorf("jar init error: %w", err) + } + + c.http.Jar = jar + } + + if !c.isLoggedIn() { + err := c.Login(req.Context()) + if err != nil { + return err + } + } + + return c.do(req, obj) +} + +func (c *Client) isLoggedIn() bool { + for _, cookie := range c.http.Jar.Cookies(c.baseURL) { + if cookie.Name != "sessionid" { + continue + } + + if cookie.Expires.Before(time.Now()) { + return false + } + + return true + } + + return false +} + func (c *Client) do(req *http.Request, obj interface{}) error { req.Header.Set("content-type", "application/json") - cookies := c.http.Jar.Cookies(c.baseURL) - for _, cookie := range cookies { + for _, cookie := range c.http.Jar.Cookies(c.baseURL) { if cookie.Name != "csrftoken" { continue } @@ -261,28 +300,28 @@ func newScheduleFrom(data io.Reader) (*Schedule, error) { return &Schedule{Calendar: cal}, nil } -func (c *Client) listTeams() (teamList, error) { +func (c *Client) listTeams(ctx context.Context) (teamList, error) { url := fmt.Sprintf("%s/api/account/teams", c.baseURL) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("can't create request to list teams: %w", err) } teamsResp := []team{} - if err := c.do(req, &teamsResp); err != nil { + if err := c.doLoggedIn(req, &teamsResp); err != nil { return nil, fmt.Errorf("can't list teams: %w", err) } return teamsResp, nil } -func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) { +func (c *Client) listSchedules(ctx context.Context, teamID string) ([]apiSchedule, error) { url := fmt.Sprintf("%s/api/account/teams/%s/schedules", c.baseURL, teamID) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("error creating request for listing team schedules: %w", err) } scheduleList := []apiSchedule{} - if err := c.do(req, &scheduleList); err != nil { + if err := c.doLoggedIn(req, &scheduleList); err != nil { return nil, fmt.Errorf("error when listing schedules of team with ID %s: %w", teamID, err) } return scheduleList, nil @@ -290,20 +329,20 @@ func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) { // CombinedSchedule returns the full combined schedule of all team schedules where the given // user email is part of -func (c *Client) CombinedSchedule(email string) (*Schedule, error) { - teams, err := c.listTeams() +func (c *Client) CombinedSchedule(ctx context.Context, email string) (*Schedule, error) { + teams, err := c.listTeams(ctx) if err != nil { return nil, err } teamsOfUser := teams.teamsForUser(email) combined := ics.NewCalendarFor("zenduty-oncall") for _, team := range teamsOfUser { - schedules, err := c.listSchedules(team.ID) + schedules, err := c.listSchedules(ctx, team.ID) if err != nil { return nil, err } for _, schedule := range schedules { - calendar, err := c.GetSchedule(team.ID, schedule.ID, amountMonths) + calendar, err := c.GetSchedule(ctx, team.ID, schedule.ID, amountMonths) if err != nil { return nil, err } @@ -323,14 +362,14 @@ func (c *Client) CombinedSchedule(email string) (*Schedule, error) { return &Schedule{Calendar: combined}, nil } -func (c *Client) GetSchedule(teamID, scheduleID string, months int) (*Schedule, error) { +func (c *Client) GetSchedule(ctx context.Context, teamID, scheduleID string, months int) (*Schedule, error) { url := fmt.Sprintf("%s/api/account/teams/%s/schedules/%s/get_schedule_ics/?months=%d&is_team_or_user=1", c.baseURL, teamID, scheduleID, months) - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, fmt.Errorf("error creating request schedule %q of team %q: %w", scheduleID, teamID, err) } body := scheduleICSResponse{} - if err := c.do(req, &body); err != nil { + if err := c.doLoggedIn(req, &body); err != nil { return nil, fmt.Errorf("error requesting schedule %q of team %q: %w", scheduleID, teamID, err) } res, err := c.http.Get(body.URL) diff --git a/internal/zenduty/zenduty_test.go b/internal/zenduty/zenduty_test.go index 3a6061a..69d1edd 100644 --- a/internal/zenduty/zenduty_test.go +++ b/internal/zenduty/zenduty_test.go @@ -1,16 +1,19 @@ package zenduty import ( + "context" "os" "testing" "time" ics "github.com/arran4/golang-ical" - "github.com/stretchr/testify/require" + "github.com/matryer/is" "golang.org/x/exp/slog" ) func TestICS(t *testing.T) { + is := is.New(t) + username, password := os.Getenv("ZENDUTY_USERNAME"), os.Getenv("ZENDUTY_PASSWORD") if username == "" || password == "" { t.Skip("no ZENDUTY_USERNAME or ZENDUTY_PASSWORD env variable set") @@ -22,17 +25,22 @@ func TestICS(t *testing.T) { // one // options.Level = slog.LevelDebug logger := NewLogger(options) - z := NewClient(Logger(logger)) - require.NoError(t, z.Login(username, password)) + z := NewClient( + func(context.Context) (string, string) { return username, password }, + Logger(logger), + ) + is.NoErr(z.Login(context.Background())) - calendar, err := z.CombinedSchedule(username) - require.NoError(t, err) + calendar, err := z.CombinedSchedule(context.Background(), username) + is.NoErr(err) calendar = calendar.OnlyAttendees(username) - require.True(t, len(calendar.Events()) > 0) + is.True(len(calendar.Events()) > 0) logger.Debug("my schedule", "content", calendar.Serialize()) } func TestOnlyAttendees(t *testing.T) { + is := is.New(t) + for name, testCase := range map[string]struct { schedule *Schedule emails []string @@ -108,9 +116,9 @@ func TestOnlyAttendees(t *testing.T) { t.Run(name, func(t *testing.T) { testCase := testCase schedule := testCase.schedule.OnlyAttendees(testCase.emails...) - require.Len(t, schedule.Events(), len(testCase.expectedIDs)) + is.True(len(schedule.Events()) == len(testCase.expectedIDs)) for _, id := range testCase.expectedIDs { - require.True(t, schedule.ContainsEventID(id), "did not find ID %s in events", id) + is.True(schedule.ContainsEventID(id)) } }) } diff --git a/main.go b/main.go index 1ad2c6e..607b22b 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "net/http" @@ -32,8 +33,11 @@ func run(out io.Writer) error { loggerOpts := zenduty.LoggerOptions{} loggerOpts.Out = out logger := zenduty.NewLogger(loggerOpts) - z := zenduty.NewClient(zenduty.Logger(logger)) - if err := z.Login(username, password); err != nil { + z := zenduty.NewClient( + func(context.Context) (string, string) { return username, password }, + zenduty.Logger(logger), + ) + if err := z.Login(context.Background()); err != nil { return err } @@ -72,9 +76,9 @@ func run(out io.Writer) error { return server.ListenAndServe() } -func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle { +func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(ctx context.Context, teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - schedule, err := getSchedule(ps.ByName(teamKey), ps.ByName(scheduleKey), 12) + schedule, err := getSchedule(r.Context(), ps.ByName(teamKey), ps.ByName(scheduleKey), 12) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) @@ -86,7 +90,7 @@ func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(t func myScheduleHandler(c *zenduty.Client, forUser func(httprouter.Params) string) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - schedule, err := c.CombinedSchedule(forUser(ps)) + schedule, err := c.CombinedSchedule(r.Context(), forUser(ps)) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error()))