Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

feat: re-login if necessary #7

Merged
merged 4 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
dist/
.env
8 changes: 2 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@ go 1.20
require (
github.com/arran4/golang-ical v0.1.0
github.com/julienschmidt/httprouter v1.3.0
github.com/stretchr/testify v1.8.4
github.com/matryer/is v1.4.1
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
golang.org/x/net v0.16.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require github.com/stretchr/testify v1.8.4 // indirect
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4d
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand All @@ -23,9 +23,7 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqR
golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos=
golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
89 changes: 64 additions & 25 deletions internal/zenduty/zenduty.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zenduty

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -24,9 +25,10 @@ const (
)

type Client struct {
http *http.Client
baseURL *url.URL
logger *slog.Logger
credentials func(context.Context) (username string, password string)
http *http.Client
baseURL *url.URL
logger *slog.Logger
}

type loginRequest struct {
Expand Down Expand Up @@ -71,11 +73,12 @@ func NewLogger(options LoggerOptions) *slog.Logger {

// NewClient returns a new zenduty client which can be modified by passing
// options
func NewClient(opts ...ClientOption) *Client {
func NewClient(credentials func(context.Context) (string, string), opts ...ClientOption) *Client {
c := &Client{
baseURL: defaultBaseURL(),
http: defaultHTTPClient(),
logger: NewLogger(LoggerOptions{}),
credentials: credentials,
baseURL: defaultBaseURL(),
http: defaultHTTPClient(),
logger: NewLogger(LoggerOptions{}),
}
for _, opt := range opts {
opt(c)
Expand All @@ -85,22 +88,22 @@ func NewClient(opts ...ClientOption) *Client {

type ClientOption func(c *Client)

// BaseURL sets the base URL of the zenduty client
// BaseURL sets the base URL of the Zenduty client
func BaseURL(u *url.URL) ClientOption {
return func(c *Client) {
c.baseURL = u
}
}

// Logger sets the logger of the zenduty client
// Logger sets the logger of the Zenduty client
func Logger(logger *slog.Logger) ClientOption {
return func(c *Client) {
c.logger = logger
}
}

// Login executes a login with the given username and password
func (c *Client) Login(username, password string) error {
func (c *Client) Login(ctx context.Context) error {
if c.http.Jar == nil {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
Expand All @@ -118,6 +121,7 @@ func (c *Client) Login(username, password string) error {
return fmt.Errorf("error getting login page")
}

username, password := c.credentials(ctx)
body := new(bytes.Buffer)
if err := json.NewEncoder(body).Encode(loginRequest{Email: username, Password: password}); err != nil {
return fmt.Errorf("can not encode login body: %w", err)
Expand All @@ -136,10 +140,45 @@ func (c *Client) Login(username, password string) error {
return nil
}

func (c *Client) doLoggedIn(req *http.Request, obj interface{}) error {
if c.http.Jar == nil {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
return fmt.Errorf("jar init error: %w", err)
}

c.http.Jar = jar
}

if !c.isLoggedIn() {
err := c.Login(req.Context())
if err != nil {
return err
}
}

return c.do(req, obj)
}

func (c *Client) isLoggedIn() bool {
for _, cookie := range c.http.Jar.Cookies(c.baseURL) {
if cookie.Name != "sessionid" {
continue
}

if cookie.Expires.Before(time.Now()) {
return false
}

return true
}

return false
}

func (c *Client) do(req *http.Request, obj interface{}) error {
req.Header.Set("content-type", "application/json")
cookies := c.http.Jar.Cookies(c.baseURL)
for _, cookie := range cookies {
for _, cookie := range c.http.Jar.Cookies(c.baseURL) {
if cookie.Name != "csrftoken" {
continue
}
Expand Down Expand Up @@ -261,49 +300,49 @@ func newScheduleFrom(data io.Reader) (*Schedule, error) {
return &Schedule{Calendar: cal}, nil
}

func (c *Client) listTeams() (teamList, error) {
func (c *Client) listTeams(ctx context.Context) (teamList, error) {
url := fmt.Sprintf("%s/api/account/teams", c.baseURL)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("can't create request to list teams: %w", err)
}
teamsResp := []team{}
if err := c.do(req, &teamsResp); err != nil {
if err := c.doLoggedIn(req, &teamsResp); err != nil {
return nil, fmt.Errorf("can't list teams: %w", err)
}
return teamsResp, nil
}

func (c *Client) listSchedules(teamID string) ([]apiSchedule, error) {
func (c *Client) listSchedules(ctx context.Context, teamID string) ([]apiSchedule, error) {
url := fmt.Sprintf("%s/api/account/teams/%s/schedules", c.baseURL, teamID)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request for listing team schedules: %w", err)
}

scheduleList := []apiSchedule{}
if err := c.do(req, &scheduleList); err != nil {
if err := c.doLoggedIn(req, &scheduleList); err != nil {
return nil, fmt.Errorf("error when listing schedules of team with ID %s: %w", teamID, err)
}
return scheduleList, nil
}

// CombinedSchedule returns the full combined schedule of all team schedules where the given
// user email is part of
func (c *Client) CombinedSchedule(email string) (*Schedule, error) {
teams, err := c.listTeams()
func (c *Client) CombinedSchedule(ctx context.Context, email string) (*Schedule, error) {
teams, err := c.listTeams(ctx)
if err != nil {
return nil, err
}
teamsOfUser := teams.teamsForUser(email)
combined := ics.NewCalendarFor("zenduty-oncall")
for _, team := range teamsOfUser {
schedules, err := c.listSchedules(team.ID)
schedules, err := c.listSchedules(ctx, team.ID)
if err != nil {
return nil, err
}
for _, schedule := range schedules {
calendar, err := c.GetSchedule(team.ID, schedule.ID, amountMonths)
calendar, err := c.GetSchedule(ctx, team.ID, schedule.ID, amountMonths)
if err != nil {
return nil, err
}
Expand All @@ -323,14 +362,14 @@ func (c *Client) CombinedSchedule(email string) (*Schedule, error) {
return &Schedule{Calendar: combined}, nil
}

func (c *Client) GetSchedule(teamID, scheduleID string, months int) (*Schedule, error) {
func (c *Client) GetSchedule(ctx context.Context, teamID, scheduleID string, months int) (*Schedule, error) {
url := fmt.Sprintf("%s/api/account/teams/%s/schedules/%s/get_schedule_ics/?months=%d&is_team_or_user=1", c.baseURL, teamID, scheduleID, months)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request schedule %q of team %q: %w", scheduleID, teamID, err)
}
body := scheduleICSResponse{}
if err := c.do(req, &body); err != nil {
if err := c.doLoggedIn(req, &body); err != nil {
return nil, fmt.Errorf("error requesting schedule %q of team %q: %w", scheduleID, teamID, err)
}
res, err := c.http.Get(body.URL)
Expand Down
24 changes: 16 additions & 8 deletions internal/zenduty/zenduty_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package zenduty

import (
"context"
"os"
"testing"
"time"

ics "github.com/arran4/golang-ical"
"github.com/stretchr/testify/require"
"github.com/matryer/is"
"golang.org/x/exp/slog"
)

func TestICS(t *testing.T) {
is := is.New(t)

username, password := os.Getenv("ZENDUTY_USERNAME"), os.Getenv("ZENDUTY_PASSWORD")
if username == "" || password == "" {
t.Skip("no ZENDUTY_USERNAME or ZENDUTY_PASSWORD env variable set")
Expand All @@ -22,17 +25,22 @@ func TestICS(t *testing.T) {
// one
// options.Level = slog.LevelDebug
logger := NewLogger(options)
z := NewClient(Logger(logger))
require.NoError(t, z.Login(username, password))
z := NewClient(
func(context.Context) (string, string) { return username, password },
Logger(logger),
)
is.NoErr(z.Login(context.Background()))

calendar, err := z.CombinedSchedule(username)
require.NoError(t, err)
calendar, err := z.CombinedSchedule(context.Background(), username)
is.NoErr(err)
calendar = calendar.OnlyAttendees(username)
require.True(t, len(calendar.Events()) > 0)
is.True(len(calendar.Events()) > 0)
logger.Debug("my schedule", "content", calendar.Serialize())
}

func TestOnlyAttendees(t *testing.T) {
is := is.New(t)

for name, testCase := range map[string]struct {
schedule *Schedule
emails []string
Expand Down Expand Up @@ -108,9 +116,9 @@ func TestOnlyAttendees(t *testing.T) {
t.Run(name, func(t *testing.T) {
testCase := testCase
schedule := testCase.schedule.OnlyAttendees(testCase.emails...)
require.Len(t, schedule.Events(), len(testCase.expectedIDs))
is.True(len(schedule.Events()) == len(testCase.expectedIDs))
for _, id := range testCase.expectedIDs {
require.True(t, schedule.ContainsEventID(id), "did not find ID %s in events", id)
is.True(schedule.ContainsEventID(id))
}
})
}
Expand Down
14 changes: 9 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -32,8 +33,11 @@ func run(out io.Writer) error {
loggerOpts := zenduty.LoggerOptions{}
loggerOpts.Out = out
logger := zenduty.NewLogger(loggerOpts)
z := zenduty.NewClient(zenduty.Logger(logger))
if err := z.Login(username, password); err != nil {
z := zenduty.NewClient(
func(context.Context) (string, string) { return username, password },
zenduty.Logger(logger),
)
if err := z.Login(context.Background()); err != nil {
return err
}

Expand Down Expand Up @@ -72,9 +76,9 @@ func run(out io.Writer) error {
return server.ListenAndServe()
}

func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle {
func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(ctx context.Context, teamID string, scheduleID string, months int) (*zenduty.Schedule, error)) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
schedule, err := getSchedule(ps.ByName(teamKey), ps.ByName(scheduleKey), 12)
schedule, err := getSchedule(r.Context(), ps.ByName(teamKey), ps.ByName(scheduleKey), 12)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
Expand All @@ -86,7 +90,7 @@ func byAtendeeHandler(teamKey, scheduleKey, memberKey string, getSchedule func(t

func myScheduleHandler(c *zenduty.Client, forUser func(httprouter.Params) string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
schedule, err := c.CombinedSchedule(forUser(ps))
schedule, err := c.CombinedSchedule(r.Context(), forUser(ps))
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
Expand Down