diff --git a/infra/traq/group.go b/infra/traq/group.go index 3e3d1938..c0ef2821 100644 --- a/infra/traq/group.go +++ b/infra/traq/group.go @@ -1,9 +1,7 @@ package traq import ( - "encoding/json" - "fmt" - "net/http" + "context" "github.com/gofrs/uuid" "golang.org/x/oauth2" @@ -12,51 +10,44 @@ import ( ) func (repo *TraQRepository) GetGroup(token *oauth2.Token, groupID uuid.UUID) (*traq.UserGroup, error) { - URL := fmt.Sprintf("%s/groups/%s", repo.URL, groupID) - req, err := http.NewRequest(http.MethodGet, URL, nil) + ctx := context.TODO() + apiClient := NewAPIClient(ctx, token) + group, resp, err := apiClient.GroupApi.GetUserGroup(ctx, groupID.String()).Execute() if err != nil { return nil, err } - data, err := repo.doRequest(token, req) + err = handleStatusCode(resp.StatusCode) if err != nil { return nil, err } - - group := new(traq.UserGroup) - err = json.Unmarshal(data, &group) return group, err } -func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]*traq.UserGroup, error) { - URL := fmt.Sprintf("%s/groups", repo.URL) - req, err := http.NewRequest(http.MethodGet, URL, nil) +func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]traq.UserGroup, error) { + ctx := context.TODO() + apiClient := NewAPIClient(ctx, token) + groups, resp, err := apiClient.GroupApi.GetUserGroups(ctx).Execute() if err != nil { return nil, err } - data, err := repo.doRequest(token, req) + err = handleStatusCode(resp.StatusCode) if err != nil { return nil, err } - - groups := make([]*traq.UserGroup, 0) - err = json.Unmarshal(data, &groups) return groups, err } func (repo *TraQRepository) GetUserBelongingGroupIDs(token *oauth2.Token, userID uuid.UUID) ([]uuid.UUID, error) { - URL := fmt.Sprintf("%s/users/%s", repo.URL, userID) - req, err := http.NewRequest(http.MethodGet, URL, nil) + ctx := context.TODO() + apiClient := NewAPIClient(ctx, token) + user, resp, err := apiClient.UserApi.GetUser(ctx, userID.String()).Execute() if err != nil { return nil, err } - data, err := repo.doRequest(token, req) + err = handleStatusCode(resp.StatusCode) if err != nil { return nil, err } - user := new(traq.UserDetail) - if err := json.Unmarshal(data, &user); err != nil { - return nil, err - } groups := make([]uuid.UUID, 0, len(user.Groups)) for _, gid := range user.Groups { groupUUID, err := uuid.FromString(gid) @@ -65,6 +56,5 @@ func (repo *TraQRepository) GetUserBelongingGroupIDs(token *oauth2.Token, userID } groups = append(groups, groupUUID) } - return groups, err } diff --git a/infra/traq/traq.go b/infra/traq/traq.go index adfe458d..5162acd3 100644 --- a/infra/traq/traq.go +++ b/infra/traq/traq.go @@ -5,8 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "errors" - "io" - "net/http" "net/url" "github.com/traPtitech/go-traq" @@ -66,19 +64,6 @@ func (repo *TraQRepository) GetOAuthToken(query, state, codeVerifier string) (*o return repo.Config.Exchange(ctx, code, option) } -func (repo *TraQRepository) doRequest(token *oauth2.Token, req *http.Request) ([]byte, error) { - client := repo.Config.Client(context.TODO(), token) - resp, err := client.Do(req) - if err != nil { - return nil, err - } - err = handleStatusCode(resp.StatusCode) - if err != nil { - return nil, err - } - return io.ReadAll(resp.Body) -} - func NewAPIClient(ctx context.Context, token *oauth2.Token) *traq.APIClient { traqconf := traq.NewConfiguration() conf := TraQDefaultConfig diff --git a/repository/converter.go b/repository/converter.go index 2f7f6fb4..5486d67c 100644 --- a/repository/converter.go +++ b/repository/converter.go @@ -14,13 +14,11 @@ func ConvPtraqUserToPdomainUser(src *traq.User) (dst *domain.User) { return } -func ConvSPtraqUserGroupToSPdomainGroup(src []*traq.UserGroup) (dst []*domain.Group) { +func ConvSPtraqUserGroupToSPdomainGroup(src []traq.UserGroup) (dst []*domain.Group) { dst = make([]*domain.Group, len(src)) for i := range src { - if src[i] != nil { - dst[i] = new(domain.Group) - (*dst[i]) = ConvtraqUserGroupTodomainGroup((*src[i])) - } + dst[i] = new(domain.Group) + (*dst[i]) = ConvtraqUserGroupTodomainGroup(src[i]) } return }