diff --git a/.env.example b/.env.example index 6f9a2956..43517575 100644 --- a/.env.example +++ b/.env.example @@ -9,3 +9,4 @@ TOKEN_KEY= KNOQ_VERSION= KNOQ_REVISION= DEVELOPMENT= +TRAQ_ACCESS_TOKEN= diff --git a/.gitignore b/.gitignore index dd9b185d..3406127d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ debug # for development _development/* +.env diff --git a/README.md b/README.md index 482d90ee..2adf04e0 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ knoQ の全ての機能を動作させるためには、追加の情報が必要 | KNOQ_VERSION | 環境変数 | UNKNOWN | knoQ のバージョン (github actions でイメージ作成時に指定) | | KNOQ_REVISION | 環境変数 | UNKNOWN | git の sha1 (github actions でイメージ作成時に指定) | | DEVELOPMENT | 環境変数 | | 開発時かどうか | +| TRAQ_ACCESS_TOKEN | 環境変数 | | traQ へのアクセストークン | | service.json | ファイル | 空のファイル | google calendar api に必要(権限は必要なし) | ### テスト diff --git a/compose.yml b/compose.yml index ec1953d3..6b008310 100755 --- a/compose.yml +++ b/compose.yml @@ -25,6 +25,7 @@ services: KNOQ_VERSION: ${KNOQ_VERSION:-dev} DEVELOPMENT: true GORM_LOG_LEVEL: info + TRAQ_ACCESS_TOKEN: ports: - "${APP_PORT:-3000}:3000" depends_on: diff --git a/infra/traq/group.go b/infra/traq/group.go index c0ef2821..9b05945c 100644 --- a/infra/traq/group.go +++ b/infra/traq/group.go @@ -9,9 +9,10 @@ import ( "github.com/traPtitech/go-traq" ) -func (repo *TraQRepository) GetGroup(token *oauth2.Token, groupID uuid.UUID) (*traq.UserGroup, error) { - ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) +func (repo *TraQRepository) GetGroup(groupID uuid.UUID) (*traq.UserGroup, error) { + ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken) + apiClient := traq.NewAPIClient(traqAPIConfig) + // TODO: 一定期間キャッシュする group, resp, err := apiClient.GroupApi.GetUserGroup(ctx, groupID.String()).Execute() if err != nil { return nil, err @@ -23,9 +24,10 @@ func (repo *TraQRepository) GetGroup(token *oauth2.Token, groupID uuid.UUID) (*t return group, err } -func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]traq.UserGroup, error) { - ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) +func (repo *TraQRepository) GetAllGroups() ([]traq.UserGroup, error) { + ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken) + apiClient := traq.NewAPIClient(traqAPIConfig) + // TODO: 一定期間キャッシュする groups, resp, err := apiClient.GroupApi.GetUserGroups(ctx).Execute() if err != nil { return nil, err @@ -39,7 +41,7 @@ func (repo *TraQRepository) GetAllGroups(token *oauth2.Token) ([]traq.UserGroup, func (repo *TraQRepository) GetUserBelongingGroupIDs(token *oauth2.Token, userID uuid.UUID) ([]uuid.UUID, error) { ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) + apiClient := NewOauth2APIClient(ctx, token) user, resp, err := apiClient.UserApi.GetUser(ctx, userID.String()).Execute() if err != nil { return nil, err diff --git a/infra/traq/traq.go b/infra/traq/traq.go index 5162acd3..7c5a6405 100644 --- a/infra/traq/traq.go +++ b/infra/traq/traq.go @@ -14,8 +14,9 @@ import ( // TraQRepository is traq type TraQRepository struct { - Config *oauth2.Config - URL string + Config *oauth2.Config + URL string + ServerAccessToken string } var TraQDefaultConfig = &oauth2.Config{ @@ -29,6 +30,8 @@ var TraQDefaultConfig = &oauth2.Config{ }, } +var traqAPIConfig = traq.NewConfiguration() + func newPKCE() (pkceOptions []oauth2.AuthCodeOption, codeVerifier string) { codeVerifier = random.AlphaNumeric(43, true) result := sha256.Sum256([]byte(codeVerifier)) @@ -64,8 +67,8 @@ func (repo *TraQRepository) GetOAuthToken(query, state, codeVerifier string) (*o return repo.Config.Exchange(ctx, code, option) } -func NewAPIClient(ctx context.Context, token *oauth2.Token) *traq.APIClient { - traqconf := traq.NewConfiguration() +func NewOauth2APIClient(ctx context.Context, token *oauth2.Token) *traq.APIClient { + traqconf := traqAPIConfig conf := TraQDefaultConfig traqconf.HTTPClient = conf.Client(ctx, token) apiClient := traq.NewAPIClient(traqconf) diff --git a/infra/traq/user.go b/infra/traq/user.go index 32d4751c..f8c87e0f 100644 --- a/infra/traq/user.go +++ b/infra/traq/user.go @@ -8,9 +8,10 @@ import ( "golang.org/x/oauth2" ) -func (repo *TraQRepository) GetUser(token *oauth2.Token, userID uuid.UUID) (*traq.User, error) { - ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) +func (repo *TraQRepository) GetUser(userID uuid.UUID) (*traq.User, error) { + ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken) + apiClient := traq.NewAPIClient(traqAPIConfig) + // TODO: 一定期間キャッシュする userDetail, resp, err := apiClient.UserApi.GetUser(ctx, userID.String()).Execute() if err != nil { return nil, err @@ -31,9 +32,10 @@ func (repo *TraQRepository) GetUser(token *oauth2.Token, userID uuid.UUID) (*tra return &user, err } -func (repo *TraQRepository) GetUsers(token *oauth2.Token, includeSuspended bool) ([]traq.User, error) { - ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) +func (repo *TraQRepository) GetUsers(includeSuspended bool) ([]traq.User, error) { + ctx := context.WithValue(context.TODO(), traq.ContextAccessToken, repo.ServerAccessToken) + apiClient := traq.NewAPIClient(traqAPIConfig) + // TODO: 一定期間キャッシュする users, resp, err := apiClient.UserApi.GetUsers(ctx).IncludeSuspended(includeSuspended).Execute() if err != nil { return nil, err @@ -47,7 +49,7 @@ func (repo *TraQRepository) GetUsers(token *oauth2.Token, includeSuspended bool) func (repo *TraQRepository) GetUserMe(token *oauth2.Token) (*traq.User, error) { ctx := context.TODO() - apiClient := NewAPIClient(ctx, token) + apiClient := NewOauth2APIClient(ctx, token) userDetail, resp, err := apiClient.MeApi.GetMe(ctx).Execute() if err != nil { return nil, err diff --git a/main.go b/main.go index 0e722c0a..909572ee 100644 --- a/main.go +++ b/main.go @@ -43,6 +43,10 @@ var ( webhookSecret = getenv("WEBHOOK_SECRET", "") activityChannelID = getenv("ACTIVITY_CHANNEL_ID", "") dailyChannelID = getenv("DAILY_CHANNEL_ID", "") + + // TODO: traQにClient Credential Flowが実装されたら定期的に取得するように変更する + // Issue: https://github.com/traPtitech/traQ/issues/2403 + traqAccessToken = getenv("TRAQ_ACCESS_TOKEN", "") ) func main() { @@ -66,7 +70,8 @@ func main() { TokenURL: "https://q.trap.jp/api/v3/oauth2/token", }, }, - URL: "https://q.trap.jp/api/v3", + URL: "https://q.trap.jp/api/v3", + ServerAccessToken: traqAccessToken, } repo := &repository.Repository{ GormRepo: gormRepo, diff --git a/repository/group.go b/repository/group.go index 374870e9..18b1821e 100644 --- a/repository/group.go +++ b/repository/group.go @@ -74,11 +74,7 @@ func (repo *Repository) GetGroup(groupID uuid.UUID, info *domain.ConInfo) (*doma } // traq group - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return nil, defaultErrorHandling(err) - } - g, err := repo.TraQRepo.GetGroup(t, groupID) + g, err := repo.TraQRepo.GetGroup(groupID) if err != nil { return nil, defaultErrorHandling(err) } @@ -95,16 +91,12 @@ func (repo *Repository) GetGroup(groupID uuid.UUID, info *domain.ConInfo) (*doma func (repo *Repository) GetAllGroups(info *domain.ConInfo) ([]*domain.Group, error) { groups := make([]*domain.Group, 0) - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return nil, defaultErrorHandling(err) - } gg, err := repo.GormRepo.GetAllGroups() if err != nil { return nil, defaultErrorHandling(err) } groups = append(groups, db.ConvSPGroupToSPdomainGroup(gg)...) - tg, err := repo.TraQRepo.GetAllGroups(t) + tg, err := repo.TraQRepo.GetAllGroups() if err != nil { return nil, defaultErrorHandling(err) } @@ -195,12 +187,7 @@ func (repo *Repository) getTraPGroup(info *domain.ConInfo) *domain.Group { } func (repo *Repository) GetGradeGroupNames(info *domain.ConInfo) ([]string, error) { - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return nil, defaultErrorHandling(err) - } - - groups, err := repo.TraQRepo.GetAllGroups(t) + groups, err := repo.TraQRepo.GetAllGroups() if err != nil { return nil, defaultErrorHandling(err) } diff --git a/repository/user.go b/repository/user.go index 9aa4956c..7d2ac830 100644 --- a/repository/user.go +++ b/repository/user.go @@ -17,11 +17,7 @@ func (repo *Repository) SyncUsers(info *domain.ConInfo) error { if !repo.IsPrivilege(info) { return domain.ErrForbidden } - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return defaultErrorHandling(err) - } - traQUsers, err := repo.TraQRepo.GetUsers(t, true) + traQUsers, err := repo.TraQRepo.GetUsers(true) if err != nil { return defaultErrorHandling(err) } @@ -92,18 +88,13 @@ func (repo *Repository) LoginUser(query, state, codeVerifier string) (*domain.Us } func (repo *Repository) GetUser(userID uuid.UUID, info *domain.ConInfo) (*domain.User, error) { - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return nil, defaultErrorHandling(err) - } - userMeta, err := repo.GormRepo.GetUser(userID) if err != nil { return nil, defaultErrorHandling(err) } if userMeta.Provider.Issuer == traQIssuerName { - userBody, err := repo.TraQRepo.GetUser(t, userID) + userBody, err := repo.TraQRepo.GetUser(userID) if err != nil { return nil, defaultErrorHandling(err) } @@ -120,17 +111,12 @@ func (repo *Repository) GetUserMe(info *domain.ConInfo) (*domain.User, error) { } func (repo *Repository) GetAllUsers(includeSuspend, includeBot bool, info *domain.ConInfo) ([]*domain.User, error) { - t, err := repo.GormRepo.GetToken(info.ReqUserID) - if err != nil { - return nil, defaultErrorHandling(err) - } - userMetas, err := repo.GormRepo.GetAllUsers(!includeSuspend) if err != nil { return nil, defaultErrorHandling(err) } // TODO fix - traQUserBodys, err := repo.TraQRepo.GetUsers(t, includeSuspend) + traQUserBodys, err := repo.TraQRepo.GetUsers(includeSuspend) if err != nil { return nil, defaultErrorHandling(err) }