diff --git a/.tbls.yml b/.tbls.yml index 522116d6c..4208bdb0a 100644 --- a/.tbls.yml +++ b/.tbls.yml @@ -301,9 +301,6 @@ comments: token: セッショントークン reference_id: 参照ID user_id: セッションがログインしているユーザーUUID - last_access: 最終アクセス日時 - last_ip: 最終アクセスIPアドレス - last_user_agent: 最終アクセスUserAgent data: セッションデータ(gobバイナリ) created: 生成日時 - table: oauth2_authorizes diff --git a/Makefile b/Makefile index ccfc848b2..0177fbfbf 100644 --- a/Makefile +++ b/Makefile @@ -45,17 +45,17 @@ db-gen-docs: @if [ -d "./docs/dbschema" ]; then \ rm -r ./docs/dbschema; \ fi - go run main.go migrate --reset --port $(TEST_DB_PORT) + TRAQ_MARIADB_PORT=$(TEST_DB_PORT) go run main.go migrate --reset TBLS_DSN="mysql://root:password@127.0.0.1:$(TEST_DB_PORT)/traq" tbls doc .PHONY: db-diff-docs db-diff-docs: - go run main.go migrate --reset --port $(TEST_DB_PORT) + TRAQ_MARIADB_PORT=$(TEST_DB_PORT) go run main.go migrate --reset TBLS_DSN="mysql://root:password@127.0.0.1:$(TEST_DB_PORT)/traq" tbls diff .PHONY: db-lint db-lint: - go run main.go migrate --reset --port $(TEST_DB_PORT) + TRAQ_MARIADB_PORT=$(TEST_DB_PORT) go run main.go migrate --reset TBLS_DSN="mysql://root:password@127.0.0.1:$(TEST_DB_PORT)/traq" tbls lint .PHONY: goreleaser-snapshot diff --git a/cmd/file.go b/cmd/file.go index fc022f71a..e89f6ef54 100644 --- a/cmd/file.go +++ b/cmd/file.go @@ -20,18 +20,6 @@ func fileCommand() *cobra.Command { filePruneCommand(), ) - flags := cmd.PersistentFlags() - flags.String("host", "", "database host") - bindPFlag(flags, "mariadb.host", "host") - flags.Int("port", 0, "database port") - bindPFlag(flags, "mariadb.port", "port") - flags.String("name", "", "database name") - bindPFlag(flags, "mariadb.database", "name") - flags.String("user", "", "database user") - bindPFlag(flags, "mariadb.username", "user") - flags.String("pass", "", "database password") - bindPFlag(flags, "mariadb.password", "pass") - return &cmd } diff --git a/cmd/migrate.go b/cmd/migrate.go index fb1615ac2..76615ffbf 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -28,16 +28,6 @@ func migrateCommand() *cobra.Command { } flags := cmd.Flags() - flags.String("host", "", "database host") - bindPFlag(flags, "mariadb.host", "host") - flags.Int("port", 0, "database port") - bindPFlag(flags, "mariadb.port", "port") - flags.String("name", "", "database name") - bindPFlag(flags, "mariadb.database", "name") - flags.String("user", "", "database user") - bindPFlag(flags, "mariadb.username", "user") - flags.String("pass", "", "database password") - bindPFlag(flags, "mariadb.password", "pass") flags.BoolVar(&dropDB, "reset", false, "whether to truncate database (drop all tables)") return &cmd diff --git a/cmd/migrate_v2_to_v3.go b/cmd/migrate_v2_to_v3.go index 223dea0ca..f486a53e1 100644 --- a/cmd/migrate_v2_to_v3.go +++ b/cmd/migrate_v2_to_v3.go @@ -63,18 +63,6 @@ func migrateV2ToV3Command() *cobra.Command { } flags := cmd.Flags() - flags.String("host", "", "database host") - bindPFlag(flags, "mariadb.host", "host") - flags.Int("port", 0, "database port") - bindPFlag(flags, "mariadb.port", "port") - flags.String("name", "", "database name") - bindPFlag(flags, "mariadb.database", "name") - flags.String("user", "", "database user") - bindPFlag(flags, "mariadb.username", "user") - flags.String("pass", "", "database password") - bindPFlag(flags, "mariadb.password", "pass") - flags.String("origin", "", "traQ origin") - bindPFlag(flags, "origin", "origin") flags.BoolVar(&dryRun, "dry-run", false, "dry run") flags.BoolVar(&skipConvertMessage, "skip-convert-message", false, "skip message converting") flags.IntVar(&startMessagePage, "start-message-page", 0, "start message page (zero-origin)") diff --git a/cmd/serve.go b/cmd/serve.go index a6ad74c77..303ad171e 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -8,7 +8,6 @@ import ( "github.com/leandro-lugaresi/hub" "github.com/spf13/cobra" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" "github.com/traPtitech/traQ/service" "github.com/traPtitech/traQ/utils/gormzap" "github.com/traPtitech/traQ/utils/jwt" @@ -98,13 +97,6 @@ func serveCommand() *cobra.Command { logger.Info("data initialization finished") } - // SessionStore - sessionStore, err := sessions.NewGORMStore(engine) - if err != nil { - logger.Fatal("failed to setup session store", zap.Error(err)) - } - sessions.SetStore(sessionStore) - // JWT for QRCode if priv := c.JWT.Keys.Private; priv != "" { privRaw, err := ioutil.ReadFile(priv) @@ -141,7 +133,6 @@ func serveCommand() *cobra.Command { if err := server.Shutdown(ctx); err != nil { logger.Warn("abnormal shutdown", zap.Error(err)) } - sessions.PurgeCache() logger.Info("traQ shutdown") }, } diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 0760ce2f5..86b591a85 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -81,7 +81,7 @@ func newServer(hub2 *hub.Hub, db *gorm.DB, repo repository.Repository, logger *z WS: wsStreamer, } routerConfig := providerRouterConfig(c2) - echo := router.Setup(hub2, repo, services, logger, routerConfig) + echo := router.Setup(hub2, db, repo, services, logger, routerConfig) server := &Server{ L: logger, SS: services, diff --git a/docs/dbschema/README.md b/docs/dbschema/README.md index 7a3eb29bb..88d4022d6 100644 --- a/docs/dbschema/README.md +++ b/docs/dbschema/README.md @@ -26,7 +26,7 @@ | [oauth2_clients](oauth2_clients.md) | 11 | OAuth2クライアントテーブル | BASE TABLE | | [oauth2_tokens](oauth2_tokens.md) | 11 | OAuth2トークンテーブル | BASE TABLE | | [pins](pins.md) | 4 | ピンテーブル | BASE TABLE | -| [r_sessions](r_sessions.md) | 8 | traQ API HTTPセッションテーブル | BASE TABLE | +| [r_sessions](r_sessions.md) | 5 | traQ API HTTPセッションテーブル | BASE TABLE | | [stamp_palettes](stamp_palettes.md) | 7 | スタンプパレットテーブル | BASE TABLE | | [stamps](stamps.md) | 8 | スタンプテーブル | BASE TABLE | | [stars](stars.md) | 2 | お気に入りチャンネルテーブル | BASE TABLE | diff --git a/docs/dbschema/channels.md b/docs/dbschema/channels.md index 1c04d9817..f2e402199 100644 --- a/docs/dbschema/channels.md +++ b/docs/dbschema/channels.md @@ -22,7 +22,8 @@ CREATE TABLE `channels` ( `updated_at` datetime(6) DEFAULT NULL, `deleted_at` datetime(6) DEFAULT NULL, PRIMARY KEY (`id`), - UNIQUE KEY `name_parent` (`name`,`parent_id`) + UNIQUE KEY `name_parent` (`name`,`parent_id`), + KEY `idx_channel_channels_id_is_public_is_forced` (`id`,`is_public`,`is_forced`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ``` @@ -56,6 +57,7 @@ CREATE TABLE `channels` ( | Name | Definition | | ---- | ---------- | +| idx_channel_channels_id_is_public_is_forced | KEY idx_channel_channels_id_is_public_is_forced (id, is_public, is_forced) USING BTREE | | PRIMARY | PRIMARY KEY (id) USING BTREE | | name_parent | UNIQUE KEY name_parent (name, parent_id) USING BTREE | diff --git a/docs/dbschema/messages.md b/docs/dbschema/messages.md index 103f04f9c..7989b5d62 100644 --- a/docs/dbschema/messages.md +++ b/docs/dbschema/messages.md @@ -21,6 +21,7 @@ CREATE TABLE `messages` ( KEY `idx_messages_created_at` (`created_at`), KEY `messages_user_id_users_id_foreign` (`user_id`), KEY `idx_messages_channel_id_deleted_at_created_at` (`channel_id`,`deleted_at`,`created_at`), + KEY `idx_messages_deleted_at_created_at` (`deleted_at`,`created_at`), CONSTRAINT `messages_channel_id_channels_id_foreign` FOREIGN KEY (`channel_id`) REFERENCES `channels` (`id`) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT `messages_user_id_users_id_foreign` FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON DELETE CASCADE ON UPDATE CASCADE ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 @@ -55,6 +56,7 @@ CREATE TABLE `messages` ( | idx_messages_channel_id | KEY idx_messages_channel_id (channel_id) USING BTREE | | idx_messages_channel_id_deleted_at_created_at | KEY idx_messages_channel_id_deleted_at_created_at (channel_id, deleted_at, created_at) USING BTREE | | idx_messages_created_at | KEY idx_messages_created_at (created_at) USING BTREE | +| idx_messages_deleted_at_created_at | KEY idx_messages_deleted_at_created_at (deleted_at, created_at) USING BTREE | | messages_user_id_users_id_foreign | KEY messages_user_id_users_id_foreign (user_id) USING BTREE | | PRIMARY | PRIMARY KEY (id) USING BTREE | diff --git a/docs/dbschema/r_sessions.md b/docs/dbschema/r_sessions.md index cdfd54fad..82e78b70f 100644 --- a/docs/dbschema/r_sessions.md +++ b/docs/dbschema/r_sessions.md @@ -12,9 +12,6 @@ CREATE TABLE `r_sessions` ( `token` varchar(50) NOT NULL DEFAULT '', `reference_id` char(36) DEFAULT NULL, `user_id` varchar(36) DEFAULT NULL, - `last_access` datetime(6) DEFAULT NULL, - `last_ip` text, - `last_user_agent` text, `data` longblob, `created` datetime(6) DEFAULT NULL, PRIMARY KEY (`token`), @@ -32,9 +29,6 @@ CREATE TABLE `r_sessions` ( | token | varchar(50) | | false | | | セッショントークン | | reference_id | char(36) | | true | | | 参照ID | | user_id | varchar(36) | | true | | | セッションがログインしているユーザーUUID | -| last_access | datetime(6) | | true | | | 最終アクセス日時 | -| last_ip | text | | true | | | 最終アクセスIPアドレス | -| last_user_agent | text | | true | | | 最終アクセスUserAgent | | data | longblob | | true | | | セッションデータ(gobバイナリ) | | created | datetime(6) | | true | | | 生成日時 | diff --git a/docs/dbschema/r_sessions.svg b/docs/dbschema/r_sessions.svg index ca5a78f15..c0fedfc1a 100644 --- a/docs/dbschema/r_sessions.svg +++ b/docs/dbschema/r_sessions.svg @@ -4,44 +4,35 @@ - - + + r_sessions - + r_sessions - - -r_sessions - -[BASE TABLE] - -token -[varchar(50)] - -reference_id -[char(36)] - -user_id -[varchar(36)] + + +r_sessions + +[BASE TABLE] -last_access -[datetime(6)] +token +[varchar(50)] -last_ip -[text] +reference_id +[char(36)] -last_user_agent -[text] +user_id +[varchar(36)] data [longblob] created [datetime(6)] - + diff --git a/docs/dbschema/schema.svg b/docs/dbschema/schema.svg index e135f1d85..f481512f8 100644 --- a/docs/dbschema/schema.svg +++ b/docs/dbschema/schema.svg @@ -875,35 +875,26 @@ r_sessions - - -r_sessions - -[BASE TABLE] - -token -[varchar(50)] - -reference_id -[char(36)] - -user_id -[varchar(36)] - -last_access -[datetime(6)] - -last_ip -[text] - -last_user_agent -[text] - -data -[longblob] - -created -[datetime(6)] + + +r_sessions + +[BASE TABLE] + +token +[varchar(50)] + +reference_id +[char(36)] + +user_id +[varchar(36)] + +data +[longblob] + +created +[datetime(6)] diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 73fc65afa..15aa42bdd 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2959,13 +2959,6 @@ components: id: type: string format: uuid - lastIP: - type: string - lastUserAgent: - type: string - lastAccess: - type: string - format: date-time createdAt: type: string format: date-time diff --git a/docs/v3-api.yaml b/docs/v3-api.yaml index 229a600cd..401620d62 100644 --- a/docs/v3-api.yaml +++ b/docs/v3-api.yaml @@ -4724,26 +4724,12 @@ components: type: string description: セッションUUID format: uuid - ip: - type: string - description: 最終アクセスIPアドレス - format: ipv4 - ua: - type: string - description: 最終アクセスユーザーエージェント - lastAccess: - type: string - format: date-time - description: 最終アクセス日時 issuedAt: type: string description: 発行日時 format: date-time required: - id - - ip - - ua - - lastAccess - issuedAt ActiveOAuth2Token: title: ActiveOAuth2Token diff --git a/event/topic.go b/event/topic.go index ce268d547..885445219 100644 --- a/event/topic.go +++ b/event/topic.go @@ -103,13 +103,11 @@ const ( // Fields: // message_id: uuid.UUID // channel_id: uuid.UUID - // pin_id: uuid.UUID MessagePinned = "message.pinned" // MessageUnpinned メッセージがピンから外れた // Fields: // message_id: uuid.UUID // channel_id: uuid.UUID - // pin_id: uuid.UUID MessageUnpinned = "message.unpinned" // MessageCited メッセージが引用された // Fields: diff --git a/migration/current.go b/migration/current.go index d12374fa4..06a676e92 100644 --- a/migration/current.go +++ b/migration/current.go @@ -2,7 +2,6 @@ package migration import ( "github.com/traPtitech/traQ/model" - "github.com/traPtitech/traQ/router/sessions" "gopkg.in/gormigrate.v1" ) @@ -29,6 +28,7 @@ func Migrations() []*gormigrate.Migration { v16(), // パーミッション修正 v17(), // ユーザーホームチャンネル v18(), // インデックス追加 + v19(), // httpセッション管理テーブル変更 } } @@ -75,7 +75,7 @@ func AllTables() []interface{} { &model.Channel{}, &model.ClipFolder{}, &model.User{}, - &sessions.SessionRecord{}, + &model.SessionRecord{}, } } diff --git a/migration/v19.go b/migration/v19.go new file mode 100644 index 000000000..e310d00b6 --- /dev/null +++ b/migration/v19.go @@ -0,0 +1,42 @@ +package migration + +import ( + "github.com/gofrs/uuid" + "github.com/jinzhu/gorm" + "gopkg.in/gormigrate.v1" + "time" +) + +// v19 httpセッション管理テーブル変更 +func v19() *gormigrate.Migration { + return &gormigrate.Migration{ + ID: "19", + Migrate: func(db *gorm.DB) error { + if err := db.Table(v19OldSessionRecord{}.TableName()).DropColumn("last_access").Error; err != nil { + return err + } + if err := db.Table(v19OldSessionRecord{}.TableName()).DropColumn("last_ip").Error; err != nil { + return err + } + if err := db.Table(v19OldSessionRecord{}.TableName()).DropColumn("last_user_agent").Error; err != nil { + return err + } + return nil + }, + } +} + +type v19OldSessionRecord struct { + Token string `gorm:"type:varchar(50);primary_key"` + ReferenceID uuid.UUID `gorm:"type:char(36);unique"` + UserID uuid.UUID `gorm:"type:varchar(36);index"` + LastAccess time.Time `gorm:"precision:6"` + LastIP string `gorm:"type:text"` + LastUserAgent string `gorm:"type:text"` + Data []byte `gorm:"type:longblob"` + Created time.Time `gorm:"precision:6"` +} + +func (v19OldSessionRecord) TableName() string { + return "r_sessions" +} diff --git a/model/session.go b/model/session.go new file mode 100644 index 000000000..2b57bde5a --- /dev/null +++ b/model/session.go @@ -0,0 +1,34 @@ +package model + +import ( + "bytes" + "encoding/gob" + "github.com/gofrs/uuid" + "time" +) + +// SessionRecord GORM用Session構造体 +type SessionRecord struct { + Token string `gorm:"type:varchar(50);primary_key"` + ReferenceID uuid.UUID `gorm:"type:char(36);unique"` + UserID uuid.UUID `gorm:"type:varchar(36);index"` + Data []byte `gorm:"type:longblob"` + Created time.Time `gorm:"precision:6"` +} + +// TableName SessionRecordのテーブル名 +func (*SessionRecord) TableName() string { + return "r_sessions" +} + +func (sr *SessionRecord) SetData(data map[string]interface{}) { + var b bytes.Buffer + if err := gob.NewEncoder(&b).Encode(data); err != nil { + panic(err) // gobにdataの中身の構造体が登録されていない + } + sr.Data = b.Bytes() +} + +func (sr *SessionRecord) GetData() (data map[string]interface{}, err error) { + return data, gob.NewDecoder(bytes.NewReader(sr.Data)).Decode(&data) +} diff --git a/repository/stamp_impl.go b/repository/stamp_impl.go index 346305b24..d2f5a785f 100644 --- a/repository/stamp_impl.go +++ b/repository/stamp_impl.go @@ -9,29 +9,25 @@ import ( "github.com/traPtitech/traQ/event" "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/utils/validator" - "strings" "sync" "time" ) type stampRepository struct { - stamps map[uuid.UUID]*model.Stamp - nameStampsMap map[string]*model.Stamp - allJSON []byte - json []byte - updatedAt time.Time + stamps map[uuid.UUID]*model.Stamp + allJSON []byte + json []byte + updatedAt time.Time sync.RWMutex } func makeStampRepository(stamps []*model.Stamp) *stampRepository { r := &stampRepository{ - stamps: make(map[uuid.UUID]*model.Stamp, len(stamps)), - nameStampsMap: make(map[string]*model.Stamp, len(stamps)), - updatedAt: time.Now(), + stamps: make(map[uuid.UUID]*model.Stamp, len(stamps)), + updatedAt: time.Now(), } for _, s := range stamps { r.stamps[s.ID] = s - r.nameStampsMap[strings.ToLower(s.Name)] = s } r.regenerateJSON() @@ -40,32 +36,18 @@ func makeStampRepository(stamps []*model.Stamp) *stampRepository { func (r *stampRepository) add(s *model.Stamp) { r.stamps[s.ID] = s - r.nameStampsMap[strings.ToLower(s.Name)] = s r.updatedAt = time.Now() r.regenerateJSON() } func (r *stampRepository) update(s *model.Stamp) { - orig, ok := r.stamps[s.ID] - if !ok { - panic("assert !ok = false") - } - - delete(r.nameStampsMap, strings.ToLower(orig.Name)) r.stamps[s.ID] = s - r.nameStampsMap[strings.ToLower(s.Name)] = s r.updatedAt = time.Now() r.regenerateJSON() } func (r *stampRepository) delete(id uuid.UUID) { - s, ok := r.stamps[id] - if !ok { - panic("assert !ok = false") - } - delete(r.stamps, id) - delete(r.nameStampsMap, strings.ToLower(s.Name)) r.updatedAt = time.Now() r.regenerateJSON() } @@ -100,13 +82,6 @@ func (r *stampRepository) GetStamp(id uuid.UUID) (s *model.Stamp, ok bool) { return } -func (r *stampRepository) GetStampByName(name string) (s *model.Stamp, ok bool) { - r.RLock() - defer r.RUnlock() - s, ok = r.nameStampsMap[strings.ToLower(name)] - return -} - func (r *stampRepository) CheckIDs(ids []uuid.UUID) bool { r.RLock() defer r.RUnlock() diff --git a/router/auth/github.go b/router/auth/github.go index 514e33a1d..6bb74df3b 100644 --- a/router/auth/github.go +++ b/router/auth/github.go @@ -6,6 +6,7 @@ import ( json "github.com/json-iterator/go" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/session" "go.uber.org/zap" "golang.org/x/exp/utf8string" "golang.org/x/oauth2" @@ -22,10 +23,11 @@ const ( ) type GithubProvider struct { - config GithubProviderConfig - repo repository.Repository - logger *zap.Logger - oa2 oauth2.Config + config GithubProviderConfig + repo repository.Repository + logger *zap.Logger + sessStore session.Store + oa2 oauth2.Config } type GithubProviderConfig struct { @@ -98,11 +100,12 @@ func (u *githubUserInfo) IsLoginAllowedUser() bool { return true // TODO } -func NewGithubProvider(repo repository.Repository, logger *zap.Logger, config GithubProviderConfig) *GithubProvider { +func NewGithubProvider(repo repository.Repository, logger *zap.Logger, sessStore session.Store, config GithubProviderConfig) *GithubProvider { return &GithubProvider{ - repo: repo, - config: config, - logger: logger, + repo: repo, + config: config, + logger: logger, + sessStore: sessStore, oa2: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -113,11 +116,11 @@ func NewGithubProvider(repo repository.Repository, logger *zap.Logger, config Gi } func (p *GithubProvider) LoginHandler(c echo.Context) error { - return defaultLoginHandler(&p.oa2)(c) + return defaultLoginHandler(p.sessStore, &p.oa2)(c) } func (p *GithubProvider) CallbackHandler(c echo.Context) error { - return defaultCallbackHandler(p, &p.oa2, p.repo, p.config.RegisterUserIfNotFound)(c) + return defaultCallbackHandler(p, &p.oa2, p.repo, p.sessStore, p.config.RegisterUserIfNotFound)(c) } func (p *GithubProvider) FetchUserInfo(t *oauth2.Token) (UserInfo, error) { diff --git a/router/auth/google.go b/router/auth/google.go index 7b3cef457..c89262caa 100644 --- a/router/auth/google.go +++ b/router/auth/google.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/session" "go.uber.org/zap" "golang.org/x/exp/utf8string" "golang.org/x/oauth2" @@ -22,10 +23,11 @@ const ( ) type GoogleProvider struct { - config GoogleProviderConfig - repo repository.Repository - logger *zap.Logger - oa2 oauth2.Config + config GoogleProviderConfig + repo repository.Repository + logger *zap.Logger + sessStore session.Store + oa2 oauth2.Config } type GoogleProviderConfig struct { @@ -96,11 +98,12 @@ func (u *googleUserInfo) IsLoginAllowedUser() bool { return true // TODO } -func NewGoogleProvider(repo repository.Repository, logger *zap.Logger, config GoogleProviderConfig) *GoogleProvider { +func NewGoogleProvider(repo repository.Repository, logger *zap.Logger, sessStore session.Store, config GoogleProviderConfig) *GoogleProvider { return &GoogleProvider{ - repo: repo, - config: config, - logger: logger, + repo: repo, + config: config, + logger: logger, + sessStore: sessStore, oa2: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -112,11 +115,11 @@ func NewGoogleProvider(repo repository.Repository, logger *zap.Logger, config Go } func (p *GoogleProvider) LoginHandler(c echo.Context) error { - return defaultLoginHandler(&p.oa2)(c) + return defaultLoginHandler(p.sessStore, &p.oa2)(c) } func (p *GoogleProvider) CallbackHandler(c echo.Context) error { - return defaultCallbackHandler(p, &p.oa2, p.repo, p.config.RegisterUserIfNotFound)(c) + return defaultCallbackHandler(p, &p.oa2, p.repo, p.sessStore, p.config.RegisterUserIfNotFound)(c) } func (p *GoogleProvider) FetchUserInfo(t *oauth2.Token) (UserInfo, error) { diff --git a/router/auth/oidc.go b/router/auth/oidc.go index 1e55d2a9e..d2aafd7d7 100644 --- a/router/auth/oidc.go +++ b/router/auth/oidc.go @@ -7,6 +7,7 @@ import ( "github.com/coreos/go-oidc" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/utils/optional" "go.uber.org/zap" "golang.org/x/exp/utf8string" @@ -23,11 +24,12 @@ const ( ) type OIDCProvider struct { - config OIDCProviderConfig - repo repository.Repository - logger *zap.Logger - oa2 oauth2.Config - oidc *oidc.Provider + config OIDCProviderConfig + repo repository.Repository + logger *zap.Logger + oa2 oauth2.Config + sessStore session.Store + oidc *oidc.Provider } type OIDCProviderConfig struct { @@ -105,17 +107,18 @@ func (u *oidcUserInfo) IsLoginAllowedUser() bool { return true // TODO } -func NewOIDCProvider(repo repository.Repository, logger *zap.Logger, config OIDCProviderConfig) (*OIDCProvider, error) { +func NewOIDCProvider(repo repository.Repository, logger *zap.Logger, sessStore session.Store, config OIDCProviderConfig) (*OIDCProvider, error) { p, err := oidc.NewProvider(context.Background(), config.Issuer) if err != nil { return nil, err } return &OIDCProvider{ - repo: repo, - config: config, - logger: logger, - oidc: p, + repo: repo, + config: config, + logger: logger, + sessStore: sessStore, + oidc: p, oa2: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -127,11 +130,11 @@ func NewOIDCProvider(repo repository.Repository, logger *zap.Logger, config OIDC } func (p *OIDCProvider) LoginHandler(c echo.Context) error { - return defaultLoginHandler(&p.oa2)(c) + return defaultLoginHandler(p.sessStore, &p.oa2)(c) } func (p *OIDCProvider) CallbackHandler(c echo.Context) error { - return defaultCallbackHandler(p, &p.oa2, p.repo, p.config.RegisterUserIfNotFound)(c) + return defaultCallbackHandler(p, &p.oa2, p.repo, p.sessStore, p.config.RegisterUserIfNotFound)(c) } func (p *OIDCProvider) FetchUserInfo(t *oauth2.Token) (UserInfo, error) { diff --git a/router/auth/provider.go b/router/auth/provider.go index 61ffc42dc..d559e7072 100644 --- a/router/auth/provider.go +++ b/router/auth/provider.go @@ -10,7 +10,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/service/rbac/role" "github.com/traPtitech/traQ/utils/optional" "github.com/traPtitech/traQ/utils/random" @@ -45,20 +45,20 @@ type UserInfo interface { IsLoginAllowedUser() bool } -func defaultLoginHandler(oac *oauth2.Config) echo.HandlerFunc { +func defaultLoginHandler(sessStore session.Store, oac *oauth2.Config) echo.HandlerFunc { return func(c echo.Context) error { if len(c.Request().Header.Get(echo.HeaderAuthorization)) > 0 { return herror.BadRequest("Authorization Header must not be set.") } - sess, err := sessions.Get(c.Response(), c.Request(), false) + sess, err := sessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } if isTrue(c.QueryParam("link")) { // アカウント関連付けモード - if sess == nil || sess.GetUserID() == uuid.Nil { + if sess == nil || sess.UserID() == uuid.Nil { return herror.Unauthorized("You are not logged in. Please login.") } if err := sess.Set(accountLinkingFlag, true); err != nil { @@ -66,11 +66,8 @@ func defaultLoginHandler(oac *oauth2.Config) echo.HandlerFunc { } } else { // ログインモード - if sess != nil { - if sess.GetUserID() != uuid.Nil { - return herror.BadRequest("You have already logged in. Please logout once.") - } - _ = sess.Destroy(c.Response(), c.Request()) + if sess != nil && sess.UserID() != uuid.Nil { + return herror.BadRequest("You have already logged in. Please logout once.") } } @@ -87,7 +84,7 @@ func defaultLoginHandler(oac *oauth2.Config) echo.HandlerFunc { } } -func defaultCallbackHandler(p Provider, oac *oauth2.Config, repo repository.Repository, allowSignUp bool) echo.HandlerFunc { +func defaultCallbackHandler(p Provider, oac *oauth2.Config, repo repository.Repository, sessStore session.Store, allowSignUp bool) echo.HandlerFunc { return func(c echo.Context) error { if len(c.Request().Header.Get(echo.HeaderAuthorization)) > 0 { return herror.BadRequest("Authorization Header must not be set.") @@ -121,55 +118,59 @@ func defaultCallbackHandler(p Provider, oac *oauth2.Config, repo repository.Repo return c.String(http.StatusForbidden, "You are not permitted to access traQ") } - sess, err := sessions.Get(c.Response(), c.Request(), true) + sess, err := sessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } - - if sess.Get(accountLinkingFlag) != nil { - // アカウント関連付けモード - - _ = sess.Delete(accountLinkingFlag) - if sess.GetUserID() == uuid.Nil { - return herror.Unauthorized("You are not logged in. Please login.") - } - - // ユーザーアカウント状態を確認 - user, err := repo.GetUser(sess.GetUserID(), false) - if err != nil { + if sess != nil { + if v, err := sess.Get(accountLinkingFlag); err != nil { return herror.InternalServerError(err) - } - if !user.IsActive() { - return herror.Forbidden("this account is currently suspended") - } + } else if v == true { + // アカウント関連付けモード + if err := sess.Delete(accountLinkingFlag); err != nil { + return herror.InternalServerError(err) + } + if sess.UserID() == uuid.Nil { + return herror.Unauthorized("You are not logged in. Please login.") + } - // アカウントにリンク - if err := repo.LinkExternalUserAccount(user.GetID(), repository.LinkExternalUserAccountArgs{ - ProviderName: tu.GetProviderName(), - ExternalID: tu.GetID(), - Extra: model.JSON{"externalName": tu.GetRawName()}, - }); err != nil { - switch err { - case repository.ErrAlreadyExists: - return herror.BadRequest("this account has already been linked") - default: + // ユーザーアカウント状態を確認 + user, err := repo.GetUser(sess.UserID(), false) + if err != nil { return herror.InternalServerError(err) } - } - p.L().Info("an external user account has been linked to traQ user", - zap.Stringer("id", user.GetID()), - zap.String("name", user.GetName()), - zap.String("providerName", tu.GetProviderName()), - zap.String("externalId", tu.GetID()), - zap.String("externalName", tu.GetRawName())) + if !user.IsActive() { + return herror.Forbidden("this account is currently suspended") + } - return c.Redirect(http.StatusFound, "/") // TODO リダイレクト先を設定画面に + // アカウントにリンク + if err := repo.LinkExternalUserAccount(user.GetID(), repository.LinkExternalUserAccountArgs{ + ProviderName: tu.GetProviderName(), + ExternalID: tu.GetID(), + Extra: model.JSON{"externalName": tu.GetRawName()}, + }); err != nil { + switch err { + case repository.ErrAlreadyExists: + return herror.BadRequest("this account has already been linked") + default: + return herror.InternalServerError(err) + } + } + p.L().Info("an external user account has been linked to traQ user", + zap.Stringer("id", user.GetID()), + zap.String("name", user.GetName()), + zap.String("providerName", tu.GetProviderName()), + zap.String("externalId", tu.GetID()), + zap.String("externalName", tu.GetRawName())) + + return c.Redirect(http.StatusFound, "/") // TODO リダイレクト先を設定画面に + } } // ログインモード // ログインしていないことを確認 - if sess.GetUserID() != uuid.Nil { + if sess != nil && sess.UserID() != uuid.Nil { return herror.BadRequest("You have already logged in. Please logout once.") } @@ -221,7 +222,7 @@ func defaultCallbackHandler(p Provider, oac *oauth2.Config, repo repository.Repo return herror.Forbidden("this account is currently suspended") } - if err := sess.SetUser(user.GetID()); err != nil { + if _, err := sessStore.RenewSession(c, user.GetID()); err != nil { return herror.InternalServerError(err) } p.L().Info("User was logged in by external auth", diff --git a/router/auth/traq.go b/router/auth/traq.go index 6e78fc6ba..e5964e276 100644 --- a/router/auth/traq.go +++ b/router/auth/traq.go @@ -6,6 +6,7 @@ import ( json "github.com/json-iterator/go" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/session" "go.uber.org/zap" "golang.org/x/oauth2" "io/ioutil" @@ -18,10 +19,11 @@ const ( ) type TraQProvider struct { - config TraQProviderConfig - repo repository.Repository - logger *zap.Logger - oa2 oauth2.Config + config TraQProviderConfig + repo repository.Repository + logger *zap.Logger + sessStore session.Store + oa2 oauth2.Config } type TraQProviderConfig struct { @@ -85,11 +87,12 @@ func (u *traqUserInfo) IsLoginAllowedUser() bool { return true // TODO } -func NewTraQProvider(repo repository.Repository, logger *zap.Logger, config TraQProviderConfig) *TraQProvider { +func NewTraQProvider(repo repository.Repository, logger *zap.Logger, sessStore session.Store, config TraQProviderConfig) *TraQProvider { return &TraQProvider{ - repo: repo, - config: config, - logger: logger, + repo: repo, + config: config, + logger: logger, + sessStore: sessStore, oa2: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -104,11 +107,11 @@ func NewTraQProvider(repo repository.Repository, logger *zap.Logger, config TraQ } func (p *TraQProvider) LoginHandler(c echo.Context) error { - return defaultLoginHandler(&p.oa2)(c) + return defaultLoginHandler(p.sessStore, &p.oa2)(c) } func (p *TraQProvider) CallbackHandler(c echo.Context) error { - return defaultCallbackHandler(p, &p.oa2, p.repo, p.config.RegisterUserIfNotFound)(c) + return defaultCallbackHandler(p, &p.oa2, p.repo, p.sessStore, p.config.RegisterUserIfNotFound)(c) } func (p *TraQProvider) FetchUserInfo(t *oauth2.Token) (UserInfo, error) { diff --git a/router/middlewares/no_login.go b/router/middlewares/no_login.go index 86e66197d..82ccf3f8b 100644 --- a/router/middlewares/no_login.go +++ b/router/middlewares/no_login.go @@ -1,29 +1,25 @@ package middlewares import ( - "github.com/gofrs/uuid" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" ) // NoLogin セッションが既に存在するリクエストを拒否するミドルウェア -func NoLogin() echo.MiddlewareFunc { +func NoLogin(sessStore session.Store) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if len(c.Request().Header.Get(echo.HeaderAuthorization)) > 0 { return herror.BadRequest("Authorization Header must not be set. Please logout once.") } - sess, err := sessions.Get(c.Response(), c.Request(), false) + sess, err := sessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } - if sess != nil { - if sess.GetUserID() != uuid.Nil { - return herror.BadRequest("You have already logged in. Please logout once.") - } - _ = sess.Destroy(c.Response(), c.Request()) + if sess != nil && sess.LoggedIn() { + return herror.BadRequest("You have already logged in. Please logout once.") } return next(c) diff --git a/router/middlewares/user_authenticate.go b/router/middlewares/user_authenticate.go index 18f2c4447..d855ffd6c 100644 --- a/router/middlewares/user_authenticate.go +++ b/router/middlewares/user_authenticate.go @@ -9,14 +9,14 @@ import ( "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "golang.org/x/sync/singleflight" ) const authScheme = "Bearer" // UserAuthenticate リクエスト認証ミドルウェア -func UserAuthenticate(repo repository.Repository) echo.MiddlewareFunc { +func UserAuthenticate(repo repository.Repository, sessStore session.Store) echo.MiddlewareFunc { var sfUser singleflight.Group return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -52,15 +52,15 @@ func UserAuthenticate(repo repository.Repository) echo.MiddlewareFunc { uid = token.UserID } else { // Authorizationヘッダーがないためセッションを確認する - sess, err := sessions.Get(c.Response(), c.Request(), false) + sess, err := sessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } - if sess == nil || sess.GetUserID() == uuid.Nil { + if sess == nil || !sess.LoggedIn() { return herror.Unauthorized("You are not logged in") } - uid = sess.GetUserID() + uid = sess.UserID() } // ユーザー取得 diff --git a/router/oauth2/authorization_endpoint.go b/router/oauth2/authorization_endpoint.go index 07f9f1250..4de10a9c3 100644 --- a/router/oauth2/authorization_endpoint.go +++ b/router/oauth2/authorization_endpoint.go @@ -9,7 +9,6 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" "github.com/traPtitech/traQ/utils/random" "github.com/traPtitech/traQ/utils/validator" "go.uber.org/zap" @@ -152,21 +151,20 @@ func (h *Handler) AuthorizationEndpointHandler(c echo.Context) error { req.Types = types // セッション確認 - se, err := sessions.Get(c.Response(), c.Request(), true) + se, err := h.SessStore.GetSession(c, true) if err != nil { h.L(c).Error(err.Error(), zap.Error(err)) q.Set("error", errServerError) redirectURI.RawQuery = q.Encode() return c.Redirect(http.StatusFound, redirectURI.String()) } - userID := se.GetUserID() switch req.Prompt { case "": break case "none": - u, err := h.Repo.GetUser(userID, false) + u, err := h.Repo.GetUser(se.UserID(), false) if err != nil { switch err { case repository.ErrNotFound: @@ -211,7 +209,7 @@ func (h *Handler) AuthorizationEndpointHandler(c echo.Context) error { data := &model.OAuth2Authorize{ Code: random.SecureAlphaNumeric(36), ClientID: req.ClientID, - UserID: userID, + UserID: se.UserID(), CreatedAt: time.Now(), ExpiresIn: authorizationCodeExp, RedirectURI: req.RedirectURI, @@ -278,7 +276,7 @@ func (h *Handler) AuthorizationDecideHandler(c echo.Context) error { } // セッション確認 - se, err := sessions.Get(c.Response(), c.Request(), false) + se, err := h.SessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } @@ -286,11 +284,14 @@ func (h *Handler) AuthorizationDecideHandler(c echo.Context) error { return herror.Forbidden("bad session") } - reqAuth, ok := se.Get(oauth2ContextSession).(authorizeRequest) - if !ok { + _reqAuth, err := se.Get(oauth2ContextSession) + if err != nil { + return herror.InternalServerError(err) + } + if _reqAuth == nil { return herror.Forbidden("bad session") } - userID := se.GetUserID() + reqAuth := _reqAuth.(authorizeRequest) if err := se.Delete(oauth2ContextSession); err != nil { return herror.InternalServerError(err) } @@ -334,7 +335,7 @@ func (h *Handler) AuthorizationDecideHandler(c echo.Context) error { data := &model.OAuth2Authorize{ Code: random.SecureAlphaNumeric(36), ClientID: reqAuth.ClientID, - UserID: userID, + UserID: se.UserID(), CreatedAt: time.Now(), ExpiresIn: authorizationCodeExp, RedirectURI: reqAuth.RedirectURI, diff --git a/router/oauth2/authorization_endpoint_test.go b/router/oauth2/authorization_endpoint_test.go index 514d50e25..680afd3ed 100644 --- a/router/oauth2/authorization_endpoint_test.go +++ b/router/oauth2/authorization_endpoint_test.go @@ -2,14 +2,12 @@ package oauth2 import ( "github.com/gofrs/uuid" - "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traPtitech/traQ/model" - "github.com/traPtitech/traQ/router/sessions" - random2 "github.com/traPtitech/traQ/utils/random" + "github.com/traPtitech/traQ/router/session" + "github.com/traPtitech/traQ/utils/random" "net/http" - "net/http/httptest" "testing" "time" ) @@ -39,30 +37,30 @@ func TestResponseType_valid(t *testing.T) { func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) - defaultUser := CreateUser(t, repo, rand) - session := S(t, defaultUser.GetID()) + env := Setup(t, db2) + defaultUser := env.CreateUser(t, rand) + s := env.S(t, defaultUser.GetID()) scopesRead := model.AccessScopes{} scopesRead.Add("read") client := &model.OAuth2Client{ - ID: random2.AlphaNumeric(36), + ID: random.AlphaNumeric(36), Name: "test client", Confidential: false, CreatorID: uuid.Must(uuid.NewV4()), - Secret: random2.AlphaNumeric(36), + Secret: random.AlphaNumeric(36), RedirectURI: "http://example.com", Scopes: scopesRead, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) t.Run("Success (prompt=none)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - user := CreateUser(t, repo, rand) - IssueToken(t, repo, client, user.GetID(), false) - e := R(t, server) + user := env.CreateUser(t, rand) + env.IssueToken(t, client, user.GetID(), false) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). @@ -70,7 +68,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { WithFormField("prompt", "none"). WithFormField("scope", "read"). WithFormField("nonce", "nonce"). - WithCookie(sessions.CookieName, S(t, user.GetID())). + WithCookie(session.CookieName, env.S(t, user.GetID())). Expect() res.Status(http.StatusFound) @@ -82,7 +80,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { assert.NotEmpty(loc.Query().Get("code")) } - a, err := repo.GetAuthorize(loc.Query().Get("code")) + a, err := env.Repository.GetAuthorize(loc.Query().Get("code")) if assert.NoError(err) { assert.Equal("nonce", a.Nonce) } @@ -91,7 +89,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Success (code)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). @@ -110,16 +108,18 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { assert.Equal("read", loc.Query().Get("scopes")) } - s, err := sessions.GetByToken(res.Cookie(sessions.CookieName).Value().Raw()) + s, err := env.SessStore.GetSessionByToken(res.Cookie(session.CookieName).Value().Raw()) if assert.NoError(err) { - assert.Equal("state", s.Get(oauth2ContextSession).(authorizeRequest).State) + v, err := s.Get(oauth2ContextSession) + assert.NoError(err) + assert.Equal("state", v.(authorizeRequest).State) } }) t.Run("Success (GET)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.GET("/oauth2/authorize"). WithQuery("client_id", client.ID). WithQuery("response_type", "code"). @@ -138,16 +138,18 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { assert.Equal("read", loc.Query().Get("scopes")) } - s, err := sessions.GetByToken(res.Cookie(sessions.CookieName).Value().Raw()) + s, err := env.SessStore.GetSessionByToken(res.Cookie(session.CookieName).Value().Raw()) if assert.NoError(err) { - assert.Equal("state", s.Get(oauth2ContextSession).(authorizeRequest).State) + v, err := s.Get(oauth2ContextSession) + assert.NoError(err) + assert.Equal("state", v.(authorizeRequest).State) } }) t.Run("Success With PKCE", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). @@ -168,16 +170,18 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { assert.Equal("read", loc.Query().Get("scopes")) } - s, err := sessions.GetByToken(res.Cookie(sessions.CookieName).Value().Raw()) + s, err := env.SessStore.GetSessionByToken(res.Cookie(session.CookieName).Value().Raw()) if assert.NoError(err) { - assert.Equal("state", s.Get(oauth2ContextSession).(authorizeRequest).State) - assert.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", s.Get(oauth2ContextSession).(authorizeRequest).CodeChallenge) + v, err := s.Get(oauth2ContextSession) + assert.NoError(err) + assert.Equal("state", v.(authorizeRequest).State) + assert.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", v.(authorizeRequest).CodeChallenge) } }) t.Run("Bad Request", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). Expect() res.Status(http.StatusBadRequest) @@ -187,7 +191,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Bad Request (no client)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", ""). Expect() @@ -198,7 +202,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Bad Request (unknown client)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", "unknown"). Expect() @@ -209,7 +213,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Bad Request (different redirect uri)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("redirect_uri", "http://example2.com"). @@ -222,7 +226,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (invalid pkce method)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("code_challenge_method", "S256"). @@ -240,7 +244,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (invalid scope)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("scope", "あいうえお"). @@ -257,7 +261,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (no valid scope)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("scope", "write"). @@ -274,7 +278,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (unknown response_type)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "aiueo"). @@ -291,7 +295,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (invalid response_type)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code token none"). @@ -308,7 +312,7 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (prompt=none with no session)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). @@ -326,12 +330,12 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (prompt=none without consent)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). WithFormField("prompt", "none"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect() res.Status(http.StatusFound) res.Header("Cache-Control").Equal("no-store") @@ -345,12 +349,12 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (invalid prompt)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). WithFormField("prompt", "ああああ"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect() res.Status(http.StatusFound) res.Header("Cache-Control").Equal("no-store") @@ -364,16 +368,16 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { t.Run("Found (prompt=none with broader scope)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - user := CreateUser(t, repo, rand) - _, err := repo.IssueToken(client, user.GetID(), client.RedirectURI, scopesRead, 1000, false) + user := env.CreateUser(t, rand) + _, err := env.Repository.IssueToken(client, user.GetID(), client.RedirectURI, scopesRead, 1000, false) require.NoError(t, err) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). WithFormField("response_type", "code"). WithFormField("prompt", "none"). WithFormField("scope", "read write"). - WithCookie(sessions.CookieName, S(t, user.GetID())). + WithCookie(session.CookieName, env.S(t, user.GetID())). Expect() res.Status(http.StatusFound) res.Header("Cache-Control").Equal("no-store") @@ -390,16 +394,16 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { scopes := model.AccessScopes{} scopes.Add("read", "write") client := &model.OAuth2Client{ - ID: random2.AlphaNumeric(36), + ID: random.AlphaNumeric(36), Name: "test client", Confidential: false, CreatorID: uuid.Must(uuid.NewV4()), - Secret: random2.AlphaNumeric(36), + Secret: random.AlphaNumeric(36), Scopes: scopes, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize"). WithFormField("client_id", client.ID). Expect() @@ -411,9 +415,9 @@ func TestHandlers_AuthorizationEndpointHandler(t *testing.T) { func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) - user := CreateUser(t, repo, rand) - session := S(t, user.GetID()) + env := Setup(t, db2) + user := env.CreateUser(t, rand) + s := env.S(t, user.GetID()) scopesRead := model.AccessScopes{} scopesRead.Add("read") @@ -421,44 +425,42 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { scopesReadWrite.Add("read", "write") client := &model.OAuth2Client{ - ID: random2.AlphaNumeric(36), + ID: random.AlphaNumeric(36), Name: "test client", Confidential: true, CreatorID: uuid.Must(uuid.NewV4()), - Secret: random2.AlphaNumeric(36), + Secret: random.AlphaNumeric(36), RedirectURI: "http://example.com", Scopes: scopesRead, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) MakeDecideSession := func(t *testing.T, uid uuid.UUID, client *model.OAuth2Client) string { - req := httptest.NewRequest(echo.GET, "/", nil) - rec := httptest.NewRecorder() - s, err := sessions.Get(rec, req, true) + s, err := env.SessStore.IssueSession(uid, map[string]interface{}{ + oauth2ContextSession: authorizeRequest{ + ResponseType: "code", + ClientID: client.ID, + RedirectURI: client.RedirectURI, + Scopes: scopesReadWrite, + ValidScopes: scopesRead, + State: "state", + Types: responseType{true, false, false}, + AccessTime: time.Now(), + Nonce: "nonce", + }, + }) require.NoError(t, err) - require.NoError(t, s.SetUser(uid)) - require.NoError(t, s.Set(oauth2ContextSession, authorizeRequest{ - ResponseType: "code", - ClientID: client.ID, - RedirectURI: client.RedirectURI, - Scopes: scopesReadWrite, - ValidScopes: scopesRead, - State: "state", - Types: responseType{true, false, false}, - AccessTime: time.Now(), - Nonce: "nonce", - })) - - return parseCookies(rec.Header().Get("Set-Cookie"))[sessions.CookieName].Value + + return s.Token() } t.Run("Success", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, MakeDecideSession(t, user.GetID(), client)). + WithCookie(session.CookieName, MakeDecideSession(t, user.GetID(), client)). Expect() res.Status(http.StatusFound) @@ -470,7 +472,7 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { assert.NotEmpty(loc.Query().Get("code")) } - a, err := repo.GetAuthorize(loc.Query().Get("code")) + a, err := env.Repository.GetAuthorize(loc.Query().Get("code")) if assert.NoError(err) { assert.Equal("nonce", a.Nonce) } @@ -478,9 +480,9 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Bad Request (No form)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). - WithCookie(sessions.CookieName, MakeDecideSession(t, user.GetID(), client)). + WithCookie(session.CookieName, MakeDecideSession(t, user.GetID(), client)). Expect() res.Status(http.StatusBadRequest) @@ -490,10 +492,10 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Forbidden (No oauth2ContextSession)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect() res.Status(http.StatusForbidden) @@ -503,10 +505,10 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Bad Request (client not found)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, MakeDecideSession(t, user.GetID(), &model.OAuth2Client{ID: "aaaa"})). + WithCookie(session.CookieName, MakeDecideSession(t, user.GetID(), &model.OAuth2Client{ID: "aaaa"})). Expect() res.Status(http.StatusBadRequest) @@ -517,18 +519,18 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Forbidden (client without redirect uri", func(t *testing.T) { t.Parallel() client := &model.OAuth2Client{ - ID: random2.AlphaNumeric(36), + ID: random.AlphaNumeric(36), Name: "test client", Confidential: true, CreatorID: uuid.Must(uuid.NewV4()), - Secret: random2.AlphaNumeric(36), + Secret: random.AlphaNumeric(36), Scopes: scopesRead, } - require.NoError(t, repo.SaveClient(client)) - e := R(t, server) + require.NoError(t, env.Repository.SaveClient(client)) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, MakeDecideSession(t, user.GetID(), client)). + WithCookie(session.CookieName, MakeDecideSession(t, user.GetID(), client)). Expect() res.Status(http.StatusForbidden) @@ -539,10 +541,10 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Found (deny)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "deny"). - WithCookie(sessions.CookieName, MakeDecideSession(t, user.GetID(), client)). + WithCookie(session.CookieName, MakeDecideSession(t, user.GetID(), client)). Expect() res.Status(http.StatusFound) @@ -557,25 +559,23 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Found (unsupported response type)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - req := httptest.NewRequest(echo.GET, "/", nil) - rec := httptest.NewRecorder() - s, err := sessions.Get(rec, req, true) + s, err := env.SessStore.IssueSession(user.GetID(), map[string]interface{}{ + oauth2ContextSession: authorizeRequest{ + ResponseType: "code", + ClientID: client.ID, + RedirectURI: client.RedirectURI, + Scopes: scopesReadWrite, + ValidScopes: scopesRead, + State: "state", + AccessTime: time.Now(), + }, + }) require.NoError(t, err) - require.NoError(t, s.SetUser(user.GetID())) - require.NoError(t, s.Set(oauth2ContextSession, authorizeRequest{ - ResponseType: "code", - ClientID: client.ID, - RedirectURI: client.RedirectURI, - Scopes: scopesReadWrite, - ValidScopes: scopesRead, - State: "state", - AccessTime: time.Now(), - })) - - e := R(t, server) + + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, parseCookies(rec.Header().Get("Set-Cookie"))[sessions.CookieName].Value). + WithCookie(session.CookieName, s.Token()). Expect() res.Status(http.StatusFound) @@ -590,25 +590,23 @@ func TestHandlers_AuthorizationDecideHandler(t *testing.T) { t.Run("Found (timeout)", func(t *testing.T) { t.Parallel() assert := assert.New(t) - req := httptest.NewRequest(echo.GET, "/", nil) - rec := httptest.NewRecorder() - s, err := sessions.Get(rec, req, true) + s, err := env.SessStore.IssueSession(user.GetID(), map[string]interface{}{ + oauth2ContextSession: authorizeRequest{ + ResponseType: "code", + ClientID: client.ID, + RedirectURI: client.RedirectURI, + Scopes: scopesReadWrite, + ValidScopes: scopesRead, + State: "state", + AccessTime: time.Now().Add(-6 * time.Minute), + }, + }) require.NoError(t, err) - require.NoError(t, s.SetUser(user.GetID())) - require.NoError(t, s.Set(oauth2ContextSession, authorizeRequest{ - ResponseType: "code", - ClientID: client.ID, - RedirectURI: client.RedirectURI, - Scopes: scopesReadWrite, - ValidScopes: scopesRead, - State: "state", - AccessTime: time.Now().Add(-6 * time.Minute), - })) - - e := R(t, server) + + e := env.R(t) res := e.POST("/oauth2/authorize/decide"). WithFormField("submit", "approve"). - WithCookie(sessions.CookieName, parseCookies(rec.Header().Get("Set-Cookie"))[sessions.CookieName].Value). + WithCookie(session.CookieName, s.Token()). Expect() res.Status(http.StatusFound) diff --git a/router/oauth2/oauth2.go b/router/oauth2/oauth2.go index b0f461ce6..c9ee101ce 100644 --- a/router/oauth2/oauth2.go +++ b/router/oauth2/oauth2.go @@ -7,6 +7,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/middlewares" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/service/rbac" "go.uber.org/zap" ) @@ -36,9 +37,10 @@ const ( ) type Handler struct { - RBAC rbac.RBAC - Repo repository.Repository - Logger *zap.Logger + RBAC rbac.RBAC + Repo repository.Repository + Logger *zap.Logger + SessStore session.Store Config } @@ -51,7 +53,7 @@ type Config struct { func (h *Handler) Setup(e *echo.Group) { e.GET("/authorize", h.AuthorizationEndpointHandler) - e.POST("/authorize/decide", h.AuthorizationDecideHandler, middlewares.UserAuthenticate(h.Repo), middlewares.BlockBot(h.Repo)) + e.POST("/authorize/decide", h.AuthorizationDecideHandler, middlewares.UserAuthenticate(h.Repo, h.SessStore), middlewares.BlockBot(h.Repo)) e.POST("/authorize", h.AuthorizationEndpointHandler) e.POST("/token", h.TokenEndpointHandler) e.POST("/revoke", h.RevokeTokenEndpointHandler) diff --git a/router/oauth2/oauth2_test.go b/router/oauth2/oauth2_test.go index ec391ceaf..55527b7e3 100644 --- a/router/oauth2/oauth2_test.go +++ b/router/oauth2/oauth2_test.go @@ -13,8 +13,8 @@ import ( "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" - "github.com/traPtitech/traQ/router/sessions" - rbac2 "github.com/traPtitech/traQ/service/rbac" + "github.com/traPtitech/traQ/router/session" + "github.com/traPtitech/traQ/service/rbac" "github.com/traPtitech/traQ/service/rbac/role" "github.com/traPtitech/traQ/utils/random" "github.com/traPtitech/traQ/utils/storage" @@ -33,12 +33,7 @@ const ( rand = "random" ) -var ( - servers = map[string]*httptest.Server{} - dbConns = map[string]*gorm.DB{} - repositories = map[string]repository.Repository{} - hubs = map[string]*hub.Hub{} -) +var envs = map[string]*Env{} func TestMain(m *testing.M) { user := getEnvOrDefault("MARIADB_USERNAME", "root") @@ -54,6 +49,8 @@ func TestMain(m *testing.M) { } for _, key := range dbs { + env := &Env{} + // テスト用データベース接続 db, err := gorm.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true", user, pass, host, port, fmt.Sprintf("%s%s", dbPrefix, key))) if err != nil { @@ -63,20 +60,20 @@ func TestMain(m *testing.M) { if err := migration.DropAll(db); err != nil { panic(err) } - dbConns[key] = db - hub := hub.New() - hubs[key] = hub + env.DB = db + env.Hub = hub.New() + env.SessStore = session.NewMemorySessionStore() // テスト用リポジトリ作成 - repo, err := repository.NewGormRepository(db, storage.NewInMemoryFileStorage(), hub, zap.NewNop()) + repo, err := repository.NewGormRepository(db, storage.NewInMemoryFileStorage(), env.Hub, zap.NewNop()) if err != nil { panic(err) } if _, err := repo.Sync(); err != nil { panic(err) } - repositories[key] = repo + env.Repository = repo // テスト用サーバー作成 e := echo.New() @@ -85,66 +82,69 @@ func TestMain(m *testing.M) { e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) e.Use(extension.Wrap(repo)) - r, err := rbac2.New(db) + r, err := rbac.New(db) if err != nil { panic(err) } config := &Handler{ - RBAC: r, - Repo: repo, - Logger: zap.NewNop(), + RBAC: r, + Repo: env.Repository, + SessStore: env.SessStore, + Logger: zap.NewNop(), Config: Config{ AccessTokenExp: 1000, IsRefreshEnabled: true, }, } config.Setup(e.Group("/oauth2")) - servers[key] = httptest.NewServer(e) + env.Server = httptest.NewServer(e) + + envs[key] = env } // テスト実行 code := m.Run() // 後始末 - for _, v := range servers { - v.Close() - } - for _, v := range dbConns { - v.Close() - } - for _, v := range hubs { - v.Close() + for _, env := range envs { + env.Server.Close() + env.DB.Close() + env.Hub.Close() } os.Exit(code) } +type Env struct { + Server *httptest.Server + DB *gorm.DB + Repository repository.Repository + Hub *hub.Hub + SessStore session.Store +} + // Setup テストセットアップ -func Setup(t *testing.T, server string) (repository.Repository, *httptest.Server) { +func Setup(t *testing.T, server string) *Env { t.Helper() - s, ok := servers[server] + env, ok := envs[server] if !ok { t.FailNow() } - repo := repositories[server] - return repo, s + return env } // S 指定ユーザーのAPIセッショントークンを発行 -func S(t *testing.T, userID uuid.UUID) string { +func (env *Env) S(t *testing.T, userID uuid.UUID) string { t.Helper() - require := require.New(t) - - sess, err := sessions.IssueNewSession("127.0.0.1", "test") - require.NoError(err) - require.NoError(sess.SetUser(userID)) - return sess.GetToken() + s, err := env.SessStore.IssueSession(userID, nil) + require.NoError(t, err) + return s.Token() } // R リクエストテスターを作成 -func R(t *testing.T, server *httptest.Server) *httpexpect.Expect { +func (env *Env) R(t *testing.T) *httpexpect.Expect { t.Helper() return httpexpect.WithConfig(httpexpect.Config{ - BaseURL: server.URL, + BaseURL: env.Server.URL, Reporter: httpexpect.NewAssertReporter(t), Printers: []httpexpect.Printer{ httpexpect.NewCurlPrinter(t), @@ -161,24 +161,24 @@ func R(t *testing.T, server *httptest.Server) *httpexpect.Expect { } // CreateUser ユーザーを必ず作成します -func CreateUser(t *testing.T, repo repository.Repository, userName string) model.UserInfo { +func (env *Env) CreateUser(t *testing.T, userName string) model.UserInfo { t.Helper() if userName == rand { userName = random.AlphaNumeric(32) } - u, err := repo.CreateUser(repository.CreateUserArgs{Name: userName, Password: "testtesttesttest", Role: role.User}) + u, err := env.Repository.CreateUser(repository.CreateUserArgs{Name: userName, Password: "testtesttesttest", Role: role.User}) require.NoError(t, err) return u } -func IssueToken(t *testing.T, repo repository.Repository, client *model.OAuth2Client, userID uuid.UUID, refresh bool) *model.OAuth2Token { +func (env *Env) IssueToken(t *testing.T, client *model.OAuth2Client, userID uuid.UUID, refresh bool) *model.OAuth2Token { t.Helper() - token, err := repo.IssueToken(client, userID, client.RedirectURI, client.Scopes, 1000, refresh) + token, err := env.Repository.IssueToken(client, userID, client.RedirectURI, client.Scopes, 1000, refresh) require.NoError(t, err) return token } -func MakeAuthorizeData(t *testing.T, repo repository.Repository, clientID string, userID uuid.UUID) *model.OAuth2Authorize { +func (env *Env) MakeAuthorizeData(t *testing.T, clientID string, userID uuid.UUID) *model.OAuth2Authorize { t.Helper() scopes := model.AccessScopes{} scopes.Add("read") @@ -193,7 +193,7 @@ func MakeAuthorizeData(t *testing.T, repo repository.Repository, clientID string OriginalScopes: scopes, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) return authorize } @@ -204,11 +204,3 @@ func getEnvOrDefault(env string, def string) string { } return s } - -func parseCookies(value string) map[string]*http.Cookie { - m := map[string]*http.Cookie{} - for _, c := range (&http.Request{Header: http.Header{"Cookie": {value}}}).Cookies() { - m[c.Name] = c - } - return m -} diff --git a/router/oauth2/revoke_token_endpoint_test.go b/router/oauth2/revoke_token_endpoint_test.go index 8d64c6c56..09862dfa6 100644 --- a/router/oauth2/revoke_token_endpoint_test.go +++ b/router/oauth2/revoke_token_endpoint_test.go @@ -11,12 +11,12 @@ import ( func TestHandlers_RevokeTokenEndpointHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db1) - user := CreateUser(t, repo, rand) + env := Setup(t, db1) + user := env.CreateUser(t, rand) t.Run("NoToken", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.POST("/oauth2/revoke"). WithFormField("token", ""). Expect(). @@ -25,31 +25,31 @@ func TestHandlers_RevokeTokenEndpointHandler(t *testing.T) { t.Run("AccessToken", func(t *testing.T) { t.Parallel() - token, err := repo.IssueToken(nil, user.GetID(), "", model.AccessScopes{}, 10000, false) + token, err := env.Repository.IssueToken(nil, user.GetID(), "", model.AccessScopes{}, 10000, false) require.NoError(t, err) - e := R(t, server) + e := env.R(t) e.POST("/oauth2/revoke"). WithFormField("token", token.AccessToken). Expect(). Status(http.StatusOK) - _, err = repo.GetTokenByID(token.ID) + _, err = env.Repository.GetTokenByID(token.ID) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("RefreshToken", func(t *testing.T) { t.Parallel() - token, err := repo.IssueToken(nil, user.GetID(), "", model.AccessScopes{}, 10000, true) + token, err := env.Repository.IssueToken(nil, user.GetID(), "", model.AccessScopes{}, 10000, true) require.NoError(t, err) - e := R(t, server) + e := env.R(t) e.POST("/oauth2/revoke"). WithFormField("token", token.RefreshToken). Expect(). Status(http.StatusOK) - _, err = repo.GetTokenByID(token.ID) + _, err = env.Repository.GetTokenByID(token.ID) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) } diff --git a/router/oauth2/token_endpoint_test.go b/router/oauth2/token_endpoint_test.go index d8d0ca6a7..4cfb906af 100644 --- a/router/oauth2/token_endpoint_test.go +++ b/router/oauth2/token_endpoint_test.go @@ -14,11 +14,11 @@ import ( func TestHandlers_TokenEndpointHandler(t *testing.T) { t.Parallel() - _, server := Setup(t, db2) + env := Setup(t, db2) t.Run("Unsupported Grant Type", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", "ああああ"). Expect() @@ -32,7 +32,7 @@ func TestHandlers_TokenEndpointHandler(t *testing.T) { func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) + env := Setup(t, db2) scopesReadWrite := model.AccessScopes{} scopesReadWrite.Add("read", "write") @@ -45,11 +45,11 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) t.Run("Success with Basic Auth", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithBasicAuth(client.ID, client.Secret). @@ -70,7 +70,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Success with form Auth", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithFormField("client_id", client.ID). @@ -92,7 +92,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Success with smaller scope", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithFormField("scope", "read"). @@ -112,7 +112,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Success with invalid scope", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithFormField("scope", "read manage_bot"). @@ -132,7 +132,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Invalid Client (No credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). Expect() @@ -145,7 +145,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Invalid Client (Wrong credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithBasicAuth(client.ID, "wrong password"). @@ -159,7 +159,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Invalid Client (Unknown client)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithBasicAuth("wrong client", "wrong password"). @@ -182,8 +182,8 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) - e := R(t, server) + require.NoError(t, env.Repository.SaveClient(client)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithBasicAuth(client.ID, client.Secret). @@ -197,7 +197,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Invalid Scope (unknown scope)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithFormField("scope", "アイウエオ"). @@ -212,7 +212,7 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { t.Run("Invalid Scope (no valid scope)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeClientCredentials). WithFormField("scope", "manage_bot"). @@ -228,8 +228,8 @@ func TestHandlers_TokenEndpointClientCredentialsHandler(t *testing.T) { func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) - user := CreateUser(t, repo, rand) + env := Setup(t, db2) + user := env.CreateUser(t, rand) scopesReadWrite := model.AccessScopes{} scopesReadWrite.Add("read", "write") @@ -242,11 +242,11 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) t.Run("Success with Basic Auth", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -269,7 +269,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Success with form Auth", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -293,7 +293,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Success with smaller scope", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -315,7 +315,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Success with invalid scope", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -346,8 +346,8 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) - e := R(t, server) + require.NoError(t, env.Repository.SaveClient(client)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -370,7 +370,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Request (No user credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithBasicAuth(client.ID, client.Secret). @@ -384,7 +384,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Grant (Wrong user credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -400,7 +400,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Client (No client credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -415,7 +415,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Client (Wrong client credentials)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -431,7 +431,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Client (Unknown client)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -447,7 +447,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Scope (unknown scope)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -464,7 +464,7 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { t.Run("Invalid Scope (no valid scope)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypePassword). WithFormField("username", user.GetName()). @@ -482,8 +482,8 @@ func TestHandlers_TokenEndpointPasswordHandler(t *testing.T) { func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) - user := CreateUser(t, repo, rand) + env := Setup(t, db2) + user := env.CreateUser(t, rand) scopesReadWrite := model.AccessScopes{} scopesReadWrite.Add("read", "write") @@ -496,7 +496,7 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) clientConf := &model.OAuth2Client{ ID: random2.AlphaNumeric(36), @@ -507,12 +507,12 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(clientConf)) + require.NoError(t, env.Repository.SaveClient(clientConf)) t.Run("Success", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, client, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, client, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -528,14 +528,14 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetTokenByRefresh(token.RefreshToken) + _, err := env.Repository.GetTokenByRefresh(token.RefreshToken) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with smaller scope", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, client, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, client, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -552,14 +552,14 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.Value("scope").String().Equal("read") - _, err := repo.GetTokenByRefresh(token.RefreshToken) + _, err := env.Repository.GetTokenByRefresh(token.RefreshToken) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with invalid scope", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, client, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, client, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -576,14 +576,14 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.Value("scope").String().Equal("read") - _, err := repo.GetTokenByRefresh(token.RefreshToken) + _, err := env.Repository.GetTokenByRefresh(token.RefreshToken) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with confidential client Basic Auth", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, clientConf, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, clientConf, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -600,14 +600,14 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetTokenByRefresh(token.RefreshToken) + _, err := env.Repository.GetTokenByRefresh(token.RefreshToken) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with confidential client form Auth", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, clientConf, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, clientConf, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -625,13 +625,13 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetTokenByRefresh(token.RefreshToken) + _, err := env.Repository.GetTokenByRefresh(token.RefreshToken) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Request (No refresh token)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). Expect() @@ -644,7 +644,7 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Run("Invalid Grant (Unknown refresh token)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", "unknown token"). @@ -658,8 +658,8 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Run("Invalid Client (No client credentials)", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, clientConf, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, clientConf, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -673,8 +673,8 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Run("Invalid Client (Wrong client credentials)", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, clientConf, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, clientConf, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -689,8 +689,8 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Run("Invalid Scope (unknown scope)", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, client, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, client, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -705,8 +705,8 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { t.Run("Invalid Scope (no valid scope)", func(t *testing.T) { t.Parallel() - token := IssueToken(t, repo, client, user.GetID(), true) - e := R(t, server) + token := env.IssueToken(t, client, user.GetID(), true) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeRefreshToken). WithFormField("refresh_token", token.RefreshToken). @@ -722,8 +722,8 @@ func TestHandlers_TokenEndpointRefreshTokenHandler(t *testing.T) { func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { t.Parallel() - repo, server := Setup(t, db2) - user := CreateUser(t, repo, rand) + env := Setup(t, db2) + user := env.CreateUser(t, rand) scopesReadWrite := model.AccessScopes{} scopesReadWrite.Add("read", "write") @@ -740,7 +740,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(client)) + require.NoError(t, env.Repository.SaveClient(client)) clientConf := &model.OAuth2Client{ ID: random2.AlphaNumeric(36), @@ -751,13 +751,13 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { RedirectURI: "http://example.com", Scopes: scopesReadWrite, } - require.NoError(t, repo.SaveClient(clientConf)) + require.NoError(t, env.Repository.SaveClient(clientConf)) t.Run("Success", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, client.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, client.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -775,14 +775,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with confidential client Basic Auth", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -800,14 +800,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Success with confidential client form Auth", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -826,7 +826,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -845,8 +845,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { CodeChallengeMethod: "plain", CodeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -865,7 +865,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -884,8 +884,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { CodeChallengeMethod: "S256", CodeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -904,7 +904,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -921,8 +921,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { OriginalScopes: scopesRead, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -940,7 +940,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { obj.Value("refresh_token").String().NotEmpty() obj.NotContainsKey("scope") - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -957,8 +957,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { OriginalScopes: scopesReadManageBot, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -978,13 +978,13 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { actual.FromString(obj.Value("scope").String().Raw()) assert.ElementsMatch(t, authorize.Scopes.StringArray(), actual.StringArray()) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Request (No code)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("redirect_uri", "http://example.com"). @@ -999,8 +999,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { t.Run("Invalid Client (No client)", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1012,14 +1012,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidClient) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Client (Wrong client credentials)", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1032,14 +1032,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidClient) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Client (Other client)", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1052,13 +1052,13 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidClient) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Grant (Wrong code)", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", "unknown"). @@ -1085,8 +1085,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { OriginalScopes: scopesReadWrite, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1099,7 +1099,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidGrant) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -1116,8 +1116,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { OriginalScopes: scopesReadWrite, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1130,14 +1130,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidClient) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Grant (different redirect)", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1150,7 +1150,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidGrant) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -1166,8 +1166,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { OriginalScopes: scopesReadWrite, Nonce: "nonce", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1180,7 +1180,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidGrant) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) @@ -1199,8 +1199,8 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { CodeChallengeMethod: "plain", CodeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", } - require.NoError(t, repo.SaveAuthorize(authorize)) - e := R(t, server) + require.NoError(t, env.Repository.SaveAuthorize(authorize)) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1213,14 +1213,14 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidRequest) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) t.Run("Invalid Request (unexpected PKCE)", func(t *testing.T) { t.Parallel() - authorize := MakeAuthorizeData(t, repo, clientConf.ID, user.GetID()) - e := R(t, server) + authorize := env.MakeAuthorizeData(t, clientConf.ID, user.GetID()) + e := env.R(t) res := e.POST("/oauth2/token"). WithFormField("grant_type", grantTypeAuthorizationCode). WithFormField("code", authorize.Code). @@ -1234,7 +1234,7 @@ func TestHandlers_TokenEndpointAuthorizationCodeHandler(t *testing.T) { res.Header("Pragma").Equal("no-cache") res.JSON().Object().Value("error").Equal(errInvalidRequest) - _, err := repo.GetAuthorize(authorize.Code) + _, err := env.Repository.GetAuthorize(authorize.Code) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) } diff --git a/router/router.go b/router/router.go index d1357f390..f5a001543 100644 --- a/router/router.go +++ b/router/router.go @@ -1,6 +1,7 @@ package router import ( + "github.com/jinzhu/gorm" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/leandro-lugaresi/hub" @@ -11,6 +12,7 @@ import ( "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/middlewares" "github.com/traPtitech/traQ/router/oauth2" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/router/v1" "github.com/traPtitech/traQ/router/v3" "github.com/traPtitech/traQ/service" @@ -19,14 +21,15 @@ import ( ) type Router struct { - e *echo.Echo - v1 *v1.Handlers - v3 *v3.Handlers - oauth2 *oauth2.Handler + e *echo.Echo + sessStore session.Store + v1 *v1.Handlers + v3 *v3.Handlers + oauth2 *oauth2.Handler } -func Setup(hub *hub.Hub, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *echo.Echo { - r := newRouter(hub, repo, ss, logger.Named("router"), config) +func Setup(hub *hub.Hub, db *gorm.DB, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *echo.Echo { + r := newRouter(hub, db, repo, ss, logger.Named("router"), config) api := r.e.Group("/api") api.GET("/metrics", echo.WrapHandler(promhttp.Handler())) @@ -40,22 +43,22 @@ func Setup(hub *hub.Hub, repo repository.Repository, ss *service.Services, logge // 外部authハンドラ extAuth := api.Group("/auth") if config.ExternalAuth.GitHub.Valid() { - p := auth.NewGithubProvider(repo, logger.Named("ext_auth"), config.ExternalAuth.GitHub) + p := auth.NewGithubProvider(repo, logger.Named("ext_auth"), r.sessStore, config.ExternalAuth.GitHub) extAuth.GET("/github", p.LoginHandler) extAuth.GET("/github/callback", p.CallbackHandler) } if config.ExternalAuth.Google.Valid() { - p := auth.NewGoogleProvider(repo, logger.Named("ext_auth"), config.ExternalAuth.Google) + p := auth.NewGoogleProvider(repo, logger.Named("ext_auth"), r.sessStore, config.ExternalAuth.Google) extAuth.GET("/google", p.LoginHandler) extAuth.GET("/google/callback", p.CallbackHandler) } if config.ExternalAuth.TraQ.Valid() { - p := auth.NewTraQProvider(repo, logger.Named("ext_auth"), config.ExternalAuth.TraQ) + p := auth.NewTraQProvider(repo, logger.Named("ext_auth"), r.sessStore, config.ExternalAuth.TraQ) extAuth.GET("/traq", p.LoginHandler) extAuth.GET("/traq/callback", p.CallbackHandler) } if config.ExternalAuth.OIDC.Valid() { - p, err := auth.NewOIDCProvider(repo, logger.Named("ext_auth"), config.ExternalAuth.OIDC) + p, err := auth.NewOIDCProvider(repo, logger.Named("ext_auth"), r.sessStore, config.ExternalAuth.OIDC) if err != nil { panic(err) } diff --git a/router/router_wire.go b/router/router_wire.go index 434eb91c7..73e10ef24 100644 --- a/router/router_wire.go +++ b/router/router_wire.go @@ -4,21 +4,24 @@ package router import ( "github.com/google/wire" + "github.com/jinzhu/gorm" "github.com/leandro-lugaresi/hub" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/oauth2" + "github.com/traPtitech/traQ/router/session" v1 "github.com/traPtitech/traQ/router/v1" v3 "github.com/traPtitech/traQ/router/v3" "github.com/traPtitech/traQ/service" "go.uber.org/zap" ) -func newRouter(hub *hub.Hub, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *Router { +func newRouter(hub *hub.Hub, db *gorm.DB, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *Router { wire.Build( service.ProviderSet, newEcho, provideOAuth2Config, provideV3Config, + session.NewGormStore, wire.Struct(new(v1.Handlers), "*"), wire.Struct(new(v3.Handlers), "*"), wire.Struct(new(oauth2.Handler), "*"), diff --git a/router/session/gorm.go b/router/session/gorm.go new file mode 100644 index 000000000..82f59f01a --- /dev/null +++ b/router/session/gorm.go @@ -0,0 +1,348 @@ +package session + +import ( + "bytes" + "encoding/gob" + "github.com/gofrs/uuid" + lru "github.com/hashicorp/golang-lru" + "github.com/jinzhu/gorm" + "github.com/labstack/echo/v4" + "github.com/traPtitech/traQ/model" + "github.com/traPtitech/traQ/utils/random" + "net/http" + "sync" + "time" +) + +func init() { + gob.Register(map[string]interface{}{}) +} + +type session struct { + t string + refID uuid.UUID + userID uuid.UUID + createdAt time.Time + + loaded bool + db *gorm.DB + data map[string]interface{} + sync.Mutex +} + +func newSession(db *gorm.DB, t string, refID uuid.UUID, userID uuid.UUID, createdAt time.Time, data map[string]interface{}) *session { + return &session{ + t: t, + refID: refID, + userID: userID, + createdAt: createdAt, + loaded: data != nil, + db: db, + data: data, + } +} + +func (s *session) Token() string { + return s.t +} + +func (s *session) RefID() uuid.UUID { + return s.refID +} + +func (s *session) UserID() uuid.UUID { + return s.userID +} + +func (s *session) CreatedAt() time.Time { + return s.createdAt +} + +func (s *session) LoggedIn() bool { + return s.userID != uuid.Nil +} + +func (s *session) Get(key string) (interface{}, error) { + s.Lock() + defer s.Unlock() + if !s.loaded { + if err := s.load(); err != nil { + return nil, err + } + } + v := s.data[key] + return v, nil +} + +func (s *session) Set(key string, value interface{}) error { + s.Lock() + defer s.Unlock() + if !s.loaded { + if err := s.load(); err != nil { + return err + } + } + s.data[key] = value + return s.save() +} + +func (s *session) Delete(key string) error { + s.Lock() + defer s.Unlock() + if !s.loaded { + if err := s.load(); err != nil { + return err + } + } + delete(s.data, key) + return s.save() +} + +func (s *session) Expired() bool { + return time.Since(s.createdAt) > time.Duration(sessionMaxAge)*time.Second +} + +func (s *session) Refreshable() bool { + return time.Since(s.createdAt) <= time.Duration(sessionMaxAge+sessionKeepAge)*time.Second +} + +func (s *session) load() error { + var r struct { + Data []byte `gorm:"type:longblob"` + } + + if err := s.db.Model(&model.SessionRecord{Token: s.t}).Select("data").Scan(&r).Error; err != nil { + return err + } + + if err := gob.NewDecoder(bytes.NewReader(r.Data)).Decode(&s.data); err != nil { + return err + } + + s.loaded = true + return nil +} + +func (s *session) save() error { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(s.data); err != nil { + panic(err) // gobにdataの中身の構造体が登録されていない + } + return s.db.Model(&model.SessionRecord{Token: s.t}).Update("data", buf.Bytes()).Error +} + +type cachedSession struct { + t string + refID uuid.UUID + userID uuid.UUID + createdAt time.Time +} + +type sessionStore struct { + db *gorm.DB + cache *lru.Cache +} + +func NewGormStore(db *gorm.DB) Store { + cache, _ := lru.New(cacheSize) + return &sessionStore{ + db: db, + cache: cache, + } +} + +func (ss *sessionStore) GetSession(c echo.Context, createIfNotExist bool) (Session, error) { + var token string + cookie, err := c.Cookie(CookieName) + if err == nil { + token = cookie.Value + } + + var s Session + if len(token) > 0 { + s, err = ss.GetSessionByToken(token) + if err != nil && err != ErrSessionNotFound { + return nil, err + } + } + + if s != nil { + if !s.Expired() { + return s, nil + } + if s.Refreshable() { + return ss.RenewSession(c, s.UserID()) + } + } + + if !createIfNotExist { + return nil, ss.RevokeSession(c) + } + + // セッション発行 + return ss.RenewSession(c, uuid.Nil) +} + +func (ss *sessionStore) GetSessionByToken(token string) (Session, error) { + if len(token) == 0 { + return nil, ErrSessionNotFound + } + + if _v, ok := ss.cache.Get(token); ok { + v := _v.(*cachedSession) + return newSession(ss.db, v.t, v.refID, v.userID, v.createdAt, nil), nil + } + + var r model.SessionRecord + err := ss.db.First(&r, &model.SessionRecord{Token: token}).Error + if err == nil { + if r.UserID != uuid.Nil { + ss.cache.Add(r.Token, &cachedSession{t: r.Token, refID: r.ReferenceID, userID: r.UserID, createdAt: r.Created}) + } + + data, err := r.GetData() + if err != nil { + return nil, err + } + return newSession(ss.db, r.Token, r.ReferenceID, r.UserID, r.Created, data), nil + } + + if gorm.IsRecordNotFoundError(err) { + return nil, ErrSessionNotFound + } + return nil, err +} + +func (ss *sessionStore) GetSessionsByUserID(userID uuid.UUID) ([]Session, error) { + if userID == uuid.Nil { + return []Session{}, nil + } + + var records []*model.SessionRecord + if err := ss.db.Find(&records, &model.SessionRecord{UserID: userID}).Error; err != nil { + return nil, err + } + + result := make([]Session, 0) + for _, r := range records { + data, err := r.GetData() + if err != nil { + return nil, err + } + s := newSession(ss.db, r.Token, r.ReferenceID, r.UserID, r.Created, data) + if s.Refreshable() { + result = append(result, s) + } + } + return result, nil +} + +func (ss *sessionStore) RevokeSession(c echo.Context) error { + cookie, err := c.Cookie(CookieName) + if err != nil { + return nil + } + + if err := ss.db.Delete(&model.SessionRecord{Token: cookie.Value}).Error; err != nil { + return err + } + ss.cache.Remove(cookie.Value) + + cookie.Value = "" + cookie.Expires = time.Unix(0, 0) + cookie.MaxAge = -1 + c.SetCookie(cookie) + return nil +} + +func (ss *sessionStore) RevokeSessionByRefID(refID uuid.UUID) error { + if refID == uuid.Nil { + return nil + } + + var r model.SessionRecord + if err := ss.db.First(&r, &model.SessionRecord{ReferenceID: refID}).Error; err != nil { + if gorm.IsRecordNotFoundError(err) { + return nil + } + return err + } + if err := ss.db.Delete(&model.SessionRecord{Token: r.Token}).Error; err != nil { + return err + } + ss.cache.Remove(r.Token) + + return nil +} + +func (ss *sessionStore) RevokeSessionsByUserID(userID uuid.UUID) error { + if userID == uuid.Nil { + return nil + } + + var rs []*model.SessionRecord + if err := ss.db.Find(&rs, &model.SessionRecord{UserID: userID}).Error; err != nil { + return err + } + if err := ss.db.Delete(&model.SessionRecord{UserID: userID}).Error; err != nil { + return err + } + + for _, r := range rs { + ss.cache.Remove(r.Token) + } + return nil +} + +func (ss *sessionStore) RenewSession(c echo.Context, userID uuid.UUID) (Session, error) { + cookie, _ := c.Cookie(CookieName) + if cookie != nil && len(cookie.Value) > 0 { + if err := ss.db.Delete(&model.SessionRecord{Token: cookie.Value}).Error; err != nil { + return nil, err + } + ss.cache.Remove(cookie.Value) + } else { + cookie = &http.Cookie{} + } + + s, err := ss.IssueSession(userID, nil) + if err != nil { + return nil, err + } + + cookie.Name = CookieName + cookie.Value = s.Token() + cookie.Expires = time.Now().Add(time.Duration(sessionMaxAge+sessionKeepAge) * time.Second) + cookie.MaxAge = sessionMaxAge + sessionKeepAge + cookie.Path = "/" + cookie.HttpOnly = true + c.SetCookie(cookie) + + return s, nil +} + +func (ss *sessionStore) IssueSession(userID uuid.UUID, data map[string]interface{}) (Session, error) { + if data == nil { + data = map[string]interface{}{} + } + + s := &model.SessionRecord{ + Token: random.SecureAlphaNumeric(50), + ReferenceID: uuid.Must(uuid.NewV4()), + UserID: userID, + Created: time.Now(), + } + s.SetData(data) + + if err := ss.db.Create(s).Error; err != nil { + return nil, err + } + ss.cache.Add(s.Token, &cachedSession{ + t: s.Token, + refID: s.ReferenceID, + userID: s.UserID, + createdAt: s.Created, + }) + + return newSession(ss.db, s.Token, s.ReferenceID, s.UserID, s.Created, data), nil +} diff --git a/router/session/memory.go b/router/session/memory.go new file mode 100644 index 000000000..c9fcc083d --- /dev/null +++ b/router/session/memory.go @@ -0,0 +1,233 @@ +package session + +import ( + "github.com/gofrs/uuid" + "github.com/labstack/echo/v4" + "github.com/traPtitech/traQ/utils/random" + "net/http" + "sync" + "time" +) + +type memorySession struct { + t string + refID uuid.UUID + userID uuid.UUID + createdAt time.Time + data map[string]interface{} + sync.Mutex +} + +func newMemorySession(t string, refID uuid.UUID, userID uuid.UUID, createdAt time.Time, data map[string]interface{}) *memorySession { + return &memorySession{ + t: t, + refID: refID, + userID: userID, + createdAt: createdAt, + data: data, + } +} + +func (s *memorySession) Token() string { + return s.t +} + +func (s *memorySession) RefID() uuid.UUID { + return s.refID +} + +func (s *memorySession) UserID() uuid.UUID { + return s.userID +} + +func (s *memorySession) CreatedAt() time.Time { + return s.createdAt +} + +func (s *memorySession) LoggedIn() bool { + return s.userID != uuid.Nil +} + +func (s *memorySession) Get(key string) (interface{}, error) { + s.Lock() + defer s.Unlock() + return s.data[key], nil +} + +func (s *memorySession) Set(key string, value interface{}) error { + s.Lock() + defer s.Unlock() + s.data[key] = value + return nil +} + +func (s *memorySession) Delete(key string) error { + s.Lock() + defer s.Unlock() + delete(s.data, key) + return nil +} + +func (s *memorySession) Expired() bool { + return time.Since(s.createdAt) > time.Duration(sessionMaxAge)*time.Second +} + +func (s *memorySession) Refreshable() bool { + return time.Since(s.createdAt) <= time.Duration(sessionMaxAge+sessionKeepAge)*time.Second +} + +type memoryStore struct { + sessions map[string]*memorySession + sync.RWMutex +} + +func NewMemorySessionStore() Store { + return &memoryStore{ + sessions: map[string]*memorySession{}, + } +} + +func (ms *memoryStore) GetSession(c echo.Context, createIfNotExist bool) (Session, error) { + var token string + cookie, err := c.Cookie(CookieName) + if err == nil { + token = cookie.Value + } + + var s Session + if len(token) > 0 { + s, err = ms.GetSessionByToken(token) + if err != nil && err != ErrSessionNotFound { + return nil, err + } + } + + if s != nil { + if !s.Expired() { + return s, nil + } + if s.Refreshable() { + return ms.RenewSession(c, s.UserID()) + } + } + + if !createIfNotExist { + return nil, ms.RevokeSession(c) + } + + // セッション発行 + return ms.RenewSession(c, uuid.Nil) +} + +func (ms *memoryStore) GetSessionByToken(token string) (Session, error) { + if len(token) == 0 { + return nil, ErrSessionNotFound + } + + ms.RLock() + defer ms.RUnlock() + s, ok := ms.sessions[token] + if !ok { + return nil, ErrSessionNotFound + } + return s, nil +} + +func (ms *memoryStore) GetSessionsByUserID(userID uuid.UUID) ([]Session, error) { + if userID == uuid.Nil { + return []Session{}, nil + } + + ms.RLock() + defer ms.RUnlock() + result := make([]Session, 0) + for _, s := range ms.sessions { + if s.UserID() == userID && s.Refreshable() { + result = append(result, s) + } + } + return result, nil +} + +func (ms *memoryStore) RevokeSession(c echo.Context) error { + cookie, err := c.Cookie(CookieName) + if err != nil { + return nil + } + + ms.Lock() + delete(ms.sessions, cookie.Value) + ms.Unlock() + + cookie.Value = "" + cookie.Expires = time.Unix(0, 0) + cookie.MaxAge = -1 + c.SetCookie(cookie) + return nil +} + +func (ms *memoryStore) RevokeSessionByRefID(refID uuid.UUID) error { + if refID == uuid.Nil { + return nil + } + ms.Lock() + defer ms.Unlock() + for k, s := range ms.sessions { + if s.RefID() == refID { + delete(ms.sessions, k) + return nil + } + } + return nil +} + +func (ms *memoryStore) RevokeSessionsByUserID(userID uuid.UUID) error { + if userID == uuid.Nil { + return nil + } + ms.Lock() + defer ms.Unlock() + for k, s := range ms.sessions { + if s.UserID() == userID { + delete(ms.sessions, k) + } + } + return nil +} + +func (ms *memoryStore) RenewSession(c echo.Context, userID uuid.UUID) (Session, error) { + cookie, _ := c.Cookie(CookieName) + if cookie != nil && len(cookie.Value) > 0 { + ms.Lock() + delete(ms.sessions, cookie.Value) + ms.Unlock() + } else { + cookie = &http.Cookie{} + } + + s, err := ms.IssueSession(userID, nil) + if err != nil { + return nil, err + } + + cookie.Name = CookieName + cookie.Value = s.Token() + cookie.Expires = time.Now().Add(time.Duration(sessionMaxAge+sessionKeepAge) * time.Second) + cookie.MaxAge = sessionMaxAge + sessionKeepAge + cookie.Path = "/" + cookie.HttpOnly = true + c.SetCookie(cookie) + + return s, nil +} + +func (ms *memoryStore) IssueSession(userID uuid.UUID, data map[string]interface{}) (Session, error) { + if data == nil { + data = map[string]interface{}{} + } + s := newMemorySession(random.SecureAlphaNumeric(50), uuid.Must(uuid.NewV4()), userID, time.Now(), data) + ms.Lock() + ms.sessions[s.Token()] = s + ms.Unlock() + return s, nil +} diff --git a/router/session/session.go b/router/session/session.go new file mode 100644 index 000000000..4ee8ef802 --- /dev/null +++ b/router/session/session.go @@ -0,0 +1,44 @@ +package session + +import ( + "errors" + "github.com/gofrs/uuid" + "github.com/labstack/echo/v4" + "time" +) + +const ( + // CookieName セッションクッキー名 + CookieName = "r_session" + sessionMaxAge = 60 * 60 * 24 * 14 // 2 weeks + sessionKeepAge = 60 * 60 * 24 * 14 // 2 weeks + cacheSize = 2048 +) + +var ErrSessionNotFound = errors.New("session not found") + +type Session interface { + Token() string + RefID() uuid.UUID + UserID() uuid.UUID + CreatedAt() time.Time + LoggedIn() bool + + Get(key string) (interface{}, error) + Set(key string, value interface{}) error + Delete(key string) error + + Expired() bool + Refreshable() bool +} + +type Store interface { + GetSession(c echo.Context, createIfNotExist bool) (Session, error) + GetSessionByToken(token string) (Session, error) + GetSessionsByUserID(userID uuid.UUID) ([]Session, error) + RevokeSession(c echo.Context) error + RevokeSessionByRefID(refID uuid.UUID) error + RevokeSessionsByUserID(userID uuid.UUID) error + RenewSession(c echo.Context, userID uuid.UUID) (Session, error) + IssueSession(userID uuid.UUID, data map[string]interface{}) (Session, error) +} diff --git a/router/sessions/config.go b/router/sessions/config.go deleted file mode 100644 index df6d65c9d..000000000 --- a/router/sessions/config.go +++ /dev/null @@ -1,11 +0,0 @@ -package sessions - -const ( - // CookieName セッションクッキー名 - CookieName = "r_session" - tableName = "r_sessions" - cacheSize = 4096 - mutexSize = 1024 - sessionMaxAge = 60 * 60 * 24 * 14 // 2 weeks - sessionKeepAge = 60 * 60 * 24 * 14 // 2 weeks -) diff --git a/router/sessions/session.go b/router/sessions/session.go deleted file mode 100644 index 3becb1b00..000000000 --- a/router/sessions/session.go +++ /dev/null @@ -1,336 +0,0 @@ -package sessions - -import ( - "encoding/gob" - "github.com/gofrs/uuid" - "github.com/labstack/echo/v4" - "github.com/traPtitech/traQ/utils" - "github.com/traPtitech/traQ/utils/random" - "net" - "net/http" - "strings" - "sync" - "time" -) - -var mutexes *utils.KeyMutex - -// Session セッション構造体 -type Session struct { - token string - referenceID uuid.UUID - userID uuid.UUID - created time.Time - lastAccess time.Time - lastIP string - lastUserAgent string - data map[string]interface{} - sync.RWMutex -} - -func init() { - gob.Register(map[string]interface{}{}) - mutexes = utils.NewKeyMutex(mutexSize) -} - -// Get セッションを取得します -func Get(rw http.ResponseWriter, req *http.Request, createIfNotExists bool) (*Session, error) { - userAgent := req.Header.Get("User-Agent") - ip := realIP(req) - - var token string - cookie, err := req.Cookie(CookieName) - if err == nil { - token = cookie.Value - } - - var session *Session - if len(token) > 0 { - mutexes.Lock(token) - defer mutexes.Unlock(token) - - var err error - session, err = store.GetByToken(token) - if err != nil { - return nil, err - } - if session == nil { - deleteCookie(cookie, rw) - } - } - - if session != nil { - session.RLock() - age := time.Since(session.created) - absent := time.Since(session.lastAccess) - session.RUnlock() - - valid := age <= time.Duration(sessionMaxAge)*time.Second - regenerate := absent <= time.Duration(sessionKeepAge)*time.Second - - if !valid { - - if regenerate { - // 最終アクセスからsessionKeepAge経過していない場合はセッションを継続 - - uid := session.GetUserID() - err := session.Destroy(rw, req) - if err != nil { - return nil, err - } - session, err := IssueNewSession(ip, userAgent) - if err != nil { - return nil, err - } - err = session.SetUser(uid) - if err != nil { - return nil, err - } - setCookie(session.token, rw) - return session, nil - } - - if err := session.Destroy(rw, req); err != nil { - return nil, err - } - - session = nil - - } else { - session.Lock() - session.lastAccess = time.Now() - session.lastUserAgent = userAgent - session.lastIP = ip - session.Unlock() - - return session, nil - } - } - - if !createIfNotExists { - return nil, nil - } - - session, err = IssueNewSession(ip, userAgent) - if err != nil { - return nil, nil - } - setCookie(session.token, rw) - - return session, nil -} - -// IssueNewSession 新しいセッションを生成します -func IssueNewSession(ip string, userAgent string) (s *Session, err error) { - session := &Session{ - token: random.SecureAlphaNumeric(50), - referenceID: uuid.Must(uuid.NewV4()), - userID: uuid.Nil, - created: time.Now(), - lastAccess: time.Now(), - lastUserAgent: userAgent, - lastIP: ip, - data: make(map[string]interface{}), - } - if err := store.Save(session.token, session); err != nil { - return nil, err - } - return session, nil -} - -// GetByToken 指定したtokenのセッションを取得します -func GetByToken(token string) (s *Session, err error) { - mutexes.Lock(token) - defer mutexes.Unlock(token) - - s, err = store.GetByToken(token) - if err != nil { - return nil, err - } - - if s != nil { - if s.Expired() { - if err := DestroyByToken(token); err != nil { - return nil, err - } - s = nil - } - } - - return s, nil -} - -// GetByUserID 指定したユーザーのセッションを全て取得します -func GetByUserID(id uuid.UUID) ([]*Session, error) { - sessions, err := store.GetByUserID(id) - if err != nil { - return nil, err - } - - var result []*Session - for _, v := range sessions { - mutexes.Lock(v.token) - if v.Expired() { - _ = DestroyByToken(v.token) - } else { - result = append(result, v) - } - mutexes.Unlock(v.token) - } - - return result, nil -} - -// DestroyByToken 指定したtokenのセッションを破棄します -func DestroyByToken(token string) error { - return store.DestroyByToken(token) -} - -// DestroyByUserID 指定したユーザーのセッションを全て破棄します -func DestroyByUserID(id uuid.UUID) error { - sessions, err := store.GetByUserID(id) - if err != nil { - return err - } - - for _, v := range sessions { - mutexes.Lock(v.token) - if err := DestroyByToken(v.token); err != nil { - mutexes.Unlock(v.token) - return err - } - mutexes.Unlock(v.token) - } - - return nil -} - -// DestroyByReferenceID 指定したユーザーのreferenceIDのセッションを破棄します -func DestroyByReferenceID(userID, referenceID uuid.UUID) error { - session, err := store.GetByReferenceID(referenceID) - if err != nil { - return err - } - if session.userID != userID { - return nil - } - mutexes.Lock(session.token) - defer mutexes.Unlock(session.token) - - return DestroyByToken(session.token) -} - -// Destroy セッションを破棄します -func (s *Session) Destroy(rw http.ResponseWriter, req *http.Request) error { - if err := DestroyByToken(s.token); err != nil { - return err - } - - cookie, err := req.Cookie(CookieName) - if err != nil { - return err - } - deleteCookie(cookie, rw) - - return nil -} - -// GetToken セッショントークンを返します -func (s *Session) GetToken() string { - return s.token -} - -// GetUserID セッションに紐づけられているユーザーのIDを返します -func (s *Session) GetUserID() uuid.UUID { - s.RLock() - defer s.RUnlock() - return s.userID -} - -// GetSessionInfo セッションの情報を返します -func (s *Session) GetSessionInfo() (referenceID uuid.UUID, created, lastAccess time.Time, lastIP, lastUserAgent string) { - s.RLock() - defer s.RUnlock() - return s.referenceID, s.created, s.lastAccess, s.lastIP, s.lastUserAgent -} - -// SetUser セッションにユーザーを紐づけます -func (s *Session) SetUser(userID uuid.UUID) error { - s.Lock() - s.userID = userID - s.Unlock() - return store.Save(s.token, s) -} - -// Get セッションから値を取り出します -func (s *Session) Get(key string) interface{} { - s.RLock() - defer s.RUnlock() - value, ok := s.data[key] - if ok { - return value - } - return nil -} - -// Set セッションに値をセットします -func (s *Session) Set(key string, value interface{}) error { - s.Lock() - s.data[key] = value - s.Unlock() - return store.Save(s.token, s) -} - -// Delete セッションから値を削除します -func (s *Session) Delete(key string) error { - s.Lock() - delete(s.data, key) - s.Unlock() - return store.Save(s.token, s) -} - -// Expired セッションの有効期限が切れているかどうか -func (s *Session) Expired() bool { - s.RLock() - age := time.Since(s.created) - s.RUnlock() - return age > time.Duration(sessionMaxAge+sessionKeepAge)*time.Second -} - -func deleteCookie(cookie *http.Cookie, rw http.ResponseWriter) { - deleted := *cookie - deleted.Value = "" - deleted.Expires = time.Unix(0, 0) - deleted.MaxAge = -1 - http.SetCookie(rw, &deleted) -} - -func setCookie(token string, rw http.ResponseWriter) { - cookie := &http.Cookie{ - Name: CookieName, - Value: token, - Expires: time.Now().Add(time.Duration(sessionMaxAge+sessionKeepAge) * time.Second), - MaxAge: sessionMaxAge + sessionKeepAge, - Path: "/", - HttpOnly: true, - } - http.SetCookie(rw, cookie) -} - -// PurgeCache キャッシュを全て解放し、その内容を永続化します -func PurgeCache() { - if s, ok := store.(CacheableStore); ok { - s.PurgeCache() - } -} - -func realIP(req *http.Request) string { - if ip := req.Header.Get(echo.HeaderXForwardedFor); ip != "" { - return strings.Split(ip, ", ")[0] - } - if ip := req.Header.Get(echo.HeaderXRealIP); ip != "" { - return ip - } - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra -} diff --git a/router/sessions/store.go b/router/sessions/store.go deleted file mode 100644 index 63f30d7c9..000000000 --- a/router/sessions/store.go +++ /dev/null @@ -1,278 +0,0 @@ -package sessions - -import ( - "bytes" - "encoding/gob" - "github.com/gofrs/uuid" - "github.com/hashicorp/golang-lru" - "github.com/jinzhu/gorm" - "sync" - "time" -) - -var store Store - -func init() { - store = NewInMemoryStore() -} - -// Store セッションストア -type Store interface { - GetByToken(token string) (*Session, error) - GetByUserID(id uuid.UUID) ([]*Session, error) - GetByReferenceID(id uuid.UUID) (*Session, error) - DestroyByToken(token string) error - Save(token string, session *Session) error -} - -// CacheableStore キャッシュ可能なセッションストア -type CacheableStore interface { - Store - PurgeCache() -} - -// SetStore ストアをセットします -func SetStore(s Store) { - store = s -} - -// InMemoryStore for test use -type InMemoryStore struct { - sync.RWMutex - sessions map[string]*Session -} - -// NewInMemoryStore インメモリストアを作成します -func NewInMemoryStore() Store { - return &InMemoryStore{ - sessions: make(map[string]*Session), - } -} - -// GetByToken gets token's session -func (s *InMemoryStore) GetByToken(token string) (*Session, error) { - s.RLock() - defer s.RUnlock() - return s.sessions[token], nil -} - -// GetByUserID gets the user's sessions -func (s *InMemoryStore) GetByUserID(id uuid.UUID) (result []*Session, err error) { - s.RLock() - defer s.RUnlock() - for _, v := range s.sessions { - v.RLock() - if v.userID == id { - result = append(result, v) - } - v.RUnlock() - } - return -} - -// GetByReferenceID gets id's session -func (s *InMemoryStore) GetByReferenceID(id uuid.UUID) (*Session, error) { - s.RLock() - defer s.RUnlock() - for _, v := range s.sessions { - if v.referenceID == id { - return v, nil - } - } - return nil, nil -} - -// DestroyByToken deletes token's session -func (s *InMemoryStore) DestroyByToken(token string) error { - s.Lock() - defer s.Unlock() - delete(s.sessions, token) - return nil -} - -// Save saves token's session -func (s *InMemoryStore) Save(token string, session *Session) error { - s.Lock() - defer s.Unlock() - s.sessions[token] = session - return nil -} - -// GORMStore GORMストア -type GORMStore struct { - sync.Mutex - db *gorm.DB - cache *lru.Cache -} - -// GetByToken gets token's session -func (s *GORMStore) GetByToken(token string) (*Session, error) { - s.Lock() - defer s.Unlock() - - if session, ok := s.cache.Get(token); ok { - return session.(*Session), nil - } - - var record SessionRecord - if err := s.db.First(&record, &SessionRecord{Token: token}).Error; err != nil { - if gorm.IsRecordNotFoundError(err) { - return nil, nil - } - return nil, err - } - session, err := record.decode() - if err != nil { - return nil, err - } - - s.cache.Add(token, session) - return session, nil -} - -// GetByUserID gets the user's sessions -func (s *GORMStore) GetByUserID(id uuid.UUID) ([]*Session, error) { - var records []*SessionRecord - if err := s.db.Where(&SessionRecord{UserID: id.String()}).Find(&records).Error; err != nil { - return nil, err - } - result := make([]*Session, len(records)) - - s.Lock() - defer s.Unlock() - for k, v := range records { - if sess, ok := s.cache.Get(v.Token); ok { - result[k] = sess.(*Session) - } else { - sess, err := v.decode() - if err == nil { - result[k] = sess - } else { - return nil, err - } - } - } - return result, nil -} - -// GetByReferenceID gets id's session -func (s *GORMStore) GetByReferenceID(id uuid.UUID) (*Session, error) { - var record SessionRecord - if err := s.db.First(&record, &SessionRecord{ReferenceID: id.String()}).Error; err != nil { - if gorm.IsRecordNotFoundError(err) { - return nil, nil - } - return nil, err - } - - s.Lock() - defer s.Unlock() - if sess, ok := s.cache.Get(record.Token); ok { - return sess.(*Session), nil - } - return record.decode() -} - -// DestroyByToken deletes token's session -func (s *GORMStore) DestroyByToken(token string) error { - s.Lock() - defer s.Unlock() - s.cache.Remove(token) - return s.db.Delete(&SessionRecord{Token: token}).Error -} - -// Save saves token's session -func (s *GORMStore) Save(token string, session *Session) error { - session.Lock() - session.lastAccess = time.Now() - session.Unlock() - - s.Lock() - defer s.Unlock() - s.cache.Add(token, session) - - sr := &SessionRecord{} - sr.encode(session) - return s.db. - Set("gorm:insert_option", "ON DUPLICATE KEY UPDATE user_id = VALUES(user_id), last_access = VALUES(last_access), last_ip = VALUES(last_ip), last_user_agent = VALUES(last_user_agent), data = VALUES(data)"). - Create(sr). - Error -} - -// PurgeCache キャッシュを全て永続化してから解放します -func (s *GORMStore) PurgeCache() { - s.Lock() - defer s.Unlock() - s.cache.Purge() -} - -// SessionRecord GORM用Session構造体 -type SessionRecord struct { - Token string `gorm:"type:varchar(50);primary_key"` - ReferenceID string `gorm:"type:char(36);unique"` - UserID string `gorm:"type:varchar(36);index"` - LastAccess time.Time `gorm:"precision:6"` - LastIP string `gorm:"type:text"` - LastUserAgent string `gorm:"type:text"` - Data []byte `gorm:"type:longblob"` - Created time.Time `gorm:"precision:6"` -} - -// TableName SessionRecordのテーブル名 -func (*SessionRecord) TableName() string { - return tableName -} - -func (sr *SessionRecord) encode(session *Session) { - session.RLock() - defer session.RUnlock() - - sr.Token = session.token - sr.ReferenceID = session.referenceID.String() - sr.UserID = session.userID.String() - sr.LastAccess = session.lastAccess - sr.LastIP = session.lastIP - sr.LastUserAgent = session.lastUserAgent - sr.Created = session.created - - buffer := bytes.Buffer{} - if err := gob.NewEncoder(&buffer).Encode(session.data); err != nil { - panic(err) // gobにdataの中身の構造体が登録されていない - } - sr.Data = buffer.Bytes() -} - -func (sr *SessionRecord) decode() (*Session, error) { - s := &Session{ - token: sr.Token, - referenceID: uuid.Must(uuid.FromString(sr.ReferenceID)), - userID: uuid.FromStringOrNil(sr.UserID), - lastAccess: sr.LastAccess, - lastIP: sr.LastIP, - lastUserAgent: sr.LastUserAgent, - created: sr.Created, - } - - if err := gob.NewDecoder(bytes.NewReader(sr.Data)).Decode(&s.data); err != nil { - return nil, err - } - - return s, nil -} - -// NewGORMStore GORMストアを作成します -func NewGORMStore(db *gorm.DB) (Store, error) { - cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { - sess := value.(*Session) - if !sess.Expired() { - sr := &SessionRecord{} - sr.encode(sess) - db.Save(sr) - } - }) - if err != nil { - return nil, err - } - - return &GORMStore{db: db, cache: cache}, nil -} diff --git a/router/utils/common.go b/router/utils/common.go index c608051c3..1e2babf40 100644 --- a/router/utils/common.go +++ b/router/utils/common.go @@ -8,7 +8,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" imaging2 "github.com/traPtitech/traQ/service/imaging" "github.com/traPtitech/traQ/utils/optional" "net/http" @@ -55,13 +55,13 @@ func ServeUserIcon(c echo.Context, repo repository.Repository, user model.UserIn } // ChangeUserPassword userIDのユーザーのパスワードを変更する -func ChangeUserPassword(c echo.Context, repo repository.Repository, userID uuid.UUID, newPassword string) error { +func ChangeUserPassword(c echo.Context, repo repository.Repository, seStore session.Store, userID uuid.UUID, newPassword string) error { if err := repo.UpdateUser(userID, repository.UpdateUserArgs{Password: optional.StringFrom(newPassword)}); err != nil { return herror.InternalServerError(err) } // ユーザーの全セッションを破棄(強制ログアウト) - _ = sessions.DestroyByUserID(userID) + _ = seStore.RevokeSessionsByUserID(userID) return c.NoContent(http.StatusNoContent) } diff --git a/router/v1/channels_test.go b/router/v1/channels_test.go index 3d4409adb..46688ac59 100644 --- a/router/v1/channels_test.go +++ b/router/v1/channels_test.go @@ -5,26 +5,26 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/utils/optional" - random2 "github.com/traPtitech/traQ/utils/random" + "github.com/traPtitech/traQ/utils/random" "net/http" "testing" ) func TestHandlers_GetChannels(t *testing.T) { t.Parallel() - repo, server, _, require, session, adminSession := setup(t, s1) + env, _, require, s, adminSession := setup(t, s1) for i := 0; i < 5; i++ { - c := mustMakeChannel(t, repo, rand) - _, err := repo.CreatePublicChannel(random2.AlphaNumeric(20), c.ID, uuid.Nil) + c := env.mustMakeChannel(t, rand) + _, err := env.Repository.CreatePublicChannel(random.AlphaNumeric(20), c.ID, uuid.Nil) require.NoError(err) } t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels"). Expect(). Status(http.StatusUnauthorized) @@ -32,9 +32,9 @@ func TestHandlers_GetChannels(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) arr := e.GET("/api/1.0/channels"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -44,9 +44,9 @@ func TestHandlers_GetChannels(t *testing.T) { t.Run("Successful2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) arr := e.GET("/api/1.0/channels"). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). Expect(). Status(http.StatusOK). JSON(). @@ -57,11 +57,11 @@ func TestHandlers_GetChannels(t *testing.T) { func TestHandlers_PostChannels(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common1) + env, _, _, s, _ := setup(t, common1) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels"). WithJSON(&PostChannelRequest{Name: "forbidden", Parent: uuid.Nil}). Expect(). @@ -70,20 +70,20 @@ func TestHandlers_PostChannels(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) - cname1 := random2.AlphaNumeric(20) + cname1 := random.AlphaNumeric(20) obj := e.POST("/api/1.0/channels"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(&PostChannelRequest{Name: cname1, Parent: uuid.Nil}). Expect(). Status(http.StatusCreated). @@ -102,9 +102,9 @@ func TestHandlers_PostChannels(t *testing.T) { c1, err := uuid.FromString(obj.Value("channelId").String().Raw()) require.NoError(t, err) - cname2 := random2.AlphaNumeric(20) + cname2 := random.AlphaNumeric(20) obj = e.POST("/api/1.0/channels"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(&PostChannelRequest{Name: cname2, Parent: c1}). Expect(). Status(http.StatusCreated). @@ -120,20 +120,20 @@ func TestHandlers_PostChannels(t *testing.T) { obj.Value("dm").Boolean().False() obj.Value("member").Array().Empty() - _, err = repo.GetChannel(uuid.FromStringOrNil(obj.Value("channelId").String().Raw())) + _, err = env.Repository.GetChannel(uuid.FromStringOrNil(obj.Value("channelId").String().Raw())) require.NoError(t, err) }) } func TestHandlers_PostChannelChildren(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common1) + env, _, _, s, _ := setup(t, common1) - pubCh := mustMakeChannel(t, repo, rand) + pubCh := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels/{channelID}/children", pubCh.ID.String()). WithJSON(map[string]string{"name": "forbidden"}). Expect(). @@ -142,9 +142,9 @@ func TestHandlers_PostChannelChildren(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels/{channelID}/children", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": "アイウエオ"}). Expect(). Status(http.StatusBadRequest) @@ -152,11 +152,11 @@ func TestHandlers_PostChannelChildren(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) - cname1 := random2.AlphaNumeric(20) + cname1 := random.AlphaNumeric(20) obj := e.POST("/api/1.0/channels/{channelID}/children", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"name": cname1}). Expect(). Status(http.StatusCreated). @@ -172,20 +172,20 @@ func TestHandlers_PostChannelChildren(t *testing.T) { obj.Value("dm").Boolean().False() obj.Value("member").Array().Empty() - _, err := repo.GetChannel(uuid.FromStringOrNil(obj.Value("channelId").String().Raw())) + _, err := env.Repository.GetChannel(uuid.FromStringOrNil(obj.Value("channelId").String().Raw())) require.NoError(t, err) }) } func TestHandlers_GetChannelByChannelID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common1) + env, _, _, s, _ := setup(t, common1) - pubCh := mustMakeChannel(t, repo, rand) + pubCh := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}", pubCh.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -193,18 +193,18 @@ func TestHandlers_GetChannelByChannelID(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) obj := e.GET("/api/1.0/channels/{channelID}", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -223,13 +223,13 @@ func TestHandlers_GetChannelByChannelID(t *testing.T) { func TestHandlers_PatchChannelByChannelID(t *testing.T) { t.Parallel() - repo, server, _, _, session, adminSession := setup(t, common1) + env, _, _, s, adminSession := setup(t, common1) - pubCh := mustMakeChannel(t, repo, rand) + pubCh := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/channels/{channelID}", pubCh.ID.String()). WithJSON(map[string]interface{}{"name": "renamed", "visibility": true}). Expect(). @@ -238,9 +238,9 @@ func TestHandlers_PatchChannelByChannelID(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/channels/{channelID}", pubCh.ID.String()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]interface{}{"name": true, "visibility": false, "force": true}). Expect(). Status(http.StatusBadRequest) @@ -248,17 +248,17 @@ func TestHandlers_PatchChannelByChannelID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) assert, require := assertAndRequire(t) - newName := random2.AlphaNumeric(20) + newName := random.AlphaNumeric(20) e.PATCH("/api/1.0/channels/{channelID}", pubCh.ID.String()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]interface{}{"name": newName, "visibility": false, "force": true}). Expect(). Status(http.StatusNoContent) - ch, err := repo.GetChannel(pubCh.ID) + ch, err := env.Repository.GetChannel(pubCh.ID) require.NoError(err) assert.Equal(newName, ch.Name) assert.False(ch.IsVisible) @@ -268,11 +268,11 @@ func TestHandlers_PatchChannelByChannelID(t *testing.T) { // 権限がない t.Run("Failure1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) - newName := random2.AlphaNumeric(20) + newName := random.AlphaNumeric(20) e.PATCH("/api/1.0/channels/{channelID}", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": newName, "visibility": false, "force": true}). Expect(). Status(http.StatusForbidden) @@ -281,14 +281,14 @@ func TestHandlers_PatchChannelByChannelID(t *testing.T) { func TestHandlers_PutChannelParent(t *testing.T) { t.Parallel() - repo, server, _, _, session, adminSession := setup(t, common1) + env, _, _, s, adminSession := setup(t, common1) - pCh := mustMakeChannel(t, repo, rand) - cCh := mustMakeChannel(t, repo, rand) + pCh := env.mustMakeChannel(t, rand) + cCh := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/parent", cCh.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -296,14 +296,14 @@ func TestHandlers_PutChannelParent(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/parent", cCh.ID.String()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]string{"parent": pCh.ID.String()}). Expect(). Status(http.StatusNoContent) - ch, err := repo.GetChannel(cCh.ID) + ch, err := env.Repository.GetChannel(cCh.ID) require.NoError(t, err) assert.Equal(t, ch.ParentID, pCh.ID) }) @@ -311,9 +311,9 @@ func TestHandlers_PutChannelParent(t *testing.T) { // 権限がない t.Run("Failure1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/parent", cCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"parent": pCh.ID.String()}). Expect(). Status(http.StatusForbidden) @@ -322,18 +322,18 @@ func TestHandlers_PutChannelParent(t *testing.T) { func TestHandlers_GetTopic(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common1) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common1) - pubCh := mustMakeChannel(t, repo, rand) + pubCh := env.mustMakeChannel(t, rand) topicText := "Topic test" - require.NoError(t, repo.UpdateChannel(pubCh.ID, repository.UpdateChannelArgs{ + require.NoError(t, env.Repository.UpdateChannel(pubCh.ID, repository.UpdateChannelArgs{ UpdaterID: testUser.GetID(), Topic: optional.StringFrom(topicText), })) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/topic", pubCh.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -341,9 +341,9 @@ func TestHandlers_GetTopic(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/topic", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -356,11 +356,11 @@ func TestHandlers_GetTopic(t *testing.T) { func TestHandlers_PutTopic(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common1) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common1) - pubCh := mustMakeChannel(t, repo, rand) + pubCh := env.mustMakeChannel(t, rand) topicText := "Topic test" - require.NoError(t, repo.UpdateChannel(pubCh.ID, repository.UpdateChannelArgs{ + require.NoError(t, env.Repository.UpdateChannel(pubCh.ID, repository.UpdateChannelArgs{ UpdaterID: testUser.GetID(), Topic: optional.StringFrom(topicText), })) @@ -368,7 +368,7 @@ func TestHandlers_PutTopic(t *testing.T) { t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/topic", pubCh.ID.String()). WithJSON(map[string]string{"text": newTopic}). Expect(). @@ -377,9 +377,9 @@ func TestHandlers_PutTopic(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/topic", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"text": true}). Expect(). Status(http.StatusBadRequest) @@ -387,14 +387,14 @@ func TestHandlers_PutTopic(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/topic", pubCh.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"text": newTopic}). Expect(). Status(http.StatusNoContent) - ch, err := repo.GetChannel(pubCh.ID) + ch, err := env.Repository.GetChannel(pubCh.ID) require.NoError(t, err) assert.Equal(t, newTopic, ch.Topic) }) diff --git a/router/v1/files_test.go b/router/v1/files_test.go index 5ecd0ffee..3d2f6eb50 100644 --- a/router/v1/files_test.go +++ b/router/v1/files_test.go @@ -6,7 +6,7 @@ import ( "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/utils/optional" "net/http" "strings" @@ -17,12 +17,12 @@ import ( func TestHandlers_GetFileByID(t *testing.T) { t.Parallel() - repo, server, _, require, session, _ := setup(t, common1) + env, _, require, s, _ := setup(t, common1) - file := mustMakeFile(t, repo) - grantedUser := mustMakeUser(t, repo, rand) + file := env.mustMakeFile(t) + grantedUser := env.mustMakeUser(t, rand) secureContent := "secure" - secureFile, err := repo.SaveFile(repository.SaveFileArgs{ + secureFile, err := env.Repository.SaveFile(repository.SaveFileArgs{ FileName: "secure", FileSize: int64(len(secureContent)), MimeType: "text/plain", @@ -36,7 +36,7 @@ func TestHandlers_GetFileByID(t *testing.T) { t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}", file.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -44,27 +44,27 @@ func TestHandlers_GetFileByID(t *testing.T) { t.Run("Not Found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Not Accessible", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}", secureFile.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusForbidden) }) t.Run("Success", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}", file.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). Body(). @@ -73,9 +73,9 @@ func TestHandlers_GetFileByID(t *testing.T) { t.Run("Success with dl param", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) res := e.GET("/api/1.0/files/{fileID}", file.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithQuery("dl", 1). Expect(). Status(http.StatusOK) @@ -86,14 +86,14 @@ func TestHandlers_GetFileByID(t *testing.T) { t.Run("Success with icon file", func(t *testing.T) { t.Parallel() - iconFileID, err := repository.GenerateIconFile(repo, "test") + iconFileID, err := repository.GenerateIconFile(env.Repository, "test") require.NoError(err) - iconFile, err := repo.GetFileMeta(iconFileID) + iconFile, err := env.Repository.GetFileMeta(iconFileID) require.NoError(err) - e := makeExp(t, server) + e := env.makeExp(t) res := e.GET("/api/1.0/files/{fileID}", iconFile.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK) res.ContentType(iconFile.GetMIMEType()) @@ -103,9 +103,9 @@ func TestHandlers_GetFileByID(t *testing.T) { t.Run("Success With secure file", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}", secureFile.GetID()). - WithCookie(sessions.CookieName, generateSession(t, grantedUser.GetID())). + WithCookie(session.CookieName, env.generateSession(t, grantedUser.GetID())). Expect(). Status(http.StatusOK). Body(). @@ -115,13 +115,13 @@ func TestHandlers_GetFileByID(t *testing.T) { func TestHandlers_GetMetaDataByFileID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common1) + env, _, _, s, _ := setup(t, common1) - file := mustMakeFile(t, repo) + file := env.mustMakeFile(t) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}/meta", file.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -129,9 +129,9 @@ func TestHandlers_GetMetaDataByFileID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) obj := e.GET("/api/1.0/files/{fileID}/meta", file.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -148,12 +148,12 @@ func TestHandlers_GetMetaDataByFileID(t *testing.T) { func TestHandlers_GetThumbnailByID(t *testing.T) { t.Parallel() - repo, server, _, require, session, _ := setup(t, common1) + env, _, require, s, _ := setup(t, common1) - file := mustMakeFile(t, repo) - grantedUser := mustMakeUser(t, repo, rand) + file := env.mustMakeFile(t) + grantedUser := env.mustMakeUser(t, rand) secureContent := "secure" - secureFile, err := repo.SaveFile(repository.SaveFileArgs{ + secureFile, err := env.Repository.SaveFile(repository.SaveFileArgs{ FileName: "secure", FileSize: int64(len(secureContent)), MimeType: "text/plain", @@ -167,7 +167,7 @@ func TestHandlers_GetThumbnailByID(t *testing.T) { t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}/thumbnail", file.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -175,39 +175,39 @@ func TestHandlers_GetThumbnailByID(t *testing.T) { t.Run("Not Found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}/thumbnail", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Not Accessible", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}/thumbnail", secureFile.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusForbidden) }) t.Run("No Thumbnail", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/files/{fileID}/thumbnail", file.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Success", func(t *testing.T) { t.Parallel() - iconFileID, err := repository.GenerateIconFile(repo, "test") + iconFileID, err := repository.GenerateIconFile(env.Repository, "test") require.NoError(err) - e := makeExp(t, server) + e := env.makeExp(t) res := e.GET("/api/1.0/files/{fileID}/thumbnail", iconFileID). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK) res.Header(consts.HeaderCacheControl).Equal("private, max-age=31536000") diff --git a/router/v1/messages_test.go b/router/v1/messages_test.go index 6f46cd316..e60151432 100644 --- a/router/v1/messages_test.go +++ b/router/v1/messages_test.go @@ -4,21 +4,21 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "net/http" "testing" ) func TestHandlers_GetMessageByID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) + channel := env.mustMakeChannel(t, rand) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/messages/{messageID}", message.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -26,9 +26,9 @@ func TestHandlers_GetMessageByID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) obj := e.GET("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -48,13 +48,13 @@ func TestHandlers_GetMessageByID(t *testing.T) { func TestHandlers_PostMessage(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels/{channelID}/messages", channel.ID.String()). WithJSON(map[string]string{"text": "test message"}). Expect(). @@ -63,11 +63,11 @@ func TestHandlers_PostMessage(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) message := "test message" obj := e.POST("/api/1.0/channels/{channelID}/messages", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"text": message}). Expect(). Status(http.StatusCreated). @@ -84,15 +84,15 @@ func TestHandlers_PostMessage(t *testing.T) { obj.Value("updatedAt").String().NotEmpty() obj.Value("stampList").Array().Empty() - _, err := repo.GetMessageByID(uuid.FromStringOrNil(obj.Value("messageId").String().Raw())) + _, err := env.Repository.GetMessageByID(uuid.FromStringOrNil(obj.Value("messageId").String().Raw())) assert.NoError(t, err) }) t.Run("Failure2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/channels/{channelID}/messages", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"not_text_field": "not_text_field"}). Expect(). Status(http.StatusBadRequest) @@ -101,17 +101,17 @@ func TestHandlers_PostMessage(t *testing.T) { func TestHandlers_GetMessagesByChannelID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) for i := 0; i < 5; i++ { - mustMakeMessage(t, repo, testUser.GetID(), channel.ID) + env.mustMakeMessage(t, testUser.GetID(), channel.ID) } t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/messages", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -119,9 +119,9 @@ func TestHandlers_GetMessagesByChannelID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/messages", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -132,11 +132,11 @@ func TestHandlers_GetMessagesByChannelID(t *testing.T) { t.Run("Successful2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/messages", channel.ID.String()). WithQuery("limit", 3). WithQuery("offset", 1). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -148,15 +148,15 @@ func TestHandlers_GetMessagesByChannelID(t *testing.T) { func TestHandlers_PutMessageByID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) - postmanID := mustMakeUser(t, repo, rand).GetID() + channel := env.mustMakeChannel(t, rand) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) + postmanID := env.mustMakeUser(t, rand).GetID() t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/messages/{messageID}", message.ID.String()). WithJSON(map[string]string{"text": "new message"}). Expect(). @@ -165,24 +165,24 @@ func TestHandlers_PutMessageByID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) text := "new message" e.PUT("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"text": text}). Expect(). Status(http.StatusNoContent) - m, err := repo.GetMessageByID(message.ID) + m, err := env.Repository.GetMessageByID(message.ID) assert.NoError(t, err) assert.Equal(t, text, m.Text) }) t.Run("Failure2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, postmanID)). + WithCookie(session.CookieName, env.generateSession(t, postmanID)). WithJSON(map[string]string{"text": "new message"}). Expect(). Status(http.StatusForbidden) @@ -191,15 +191,15 @@ func TestHandlers_PutMessageByID(t *testing.T) { func TestHandlers_DeleteMessageByID(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) - postmanID := mustMakeUser(t, repo, rand).GetID() + channel := env.mustMakeChannel(t, rand) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) + postmanID := env.mustMakeUser(t, rand).GetID() t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/messages/{messageID}", message.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -207,49 +207,49 @@ func TestHandlers_DeleteMessageByID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - _, err := repo.GetMessageByID(message.ID) + _, err := env.Repository.GetMessageByID(message.ID) assert.Equal(t, repository.ErrNotFound, err) }) t.Run("Webhook Message", func(t *testing.T) { t.Parallel() - wb := mustMakeWebhook(t, repo, rand, channel.ID, testUser.GetID(), "") - message := mustMakeMessage(t, repo, wb.GetBotUserID(), channel.ID) + wb := env.mustMakeWebhook(t, rand, channel.ID, testUser.GetID(), "") + message := env.mustMakeMessage(t, wb.GetBotUserID(), channel.ID) - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - _, err := repo.GetMessageByID(message.ID) + _, err := env.Repository.GetMessageByID(message.ID) assert.Equal(t, repository.ErrNotFound, err) }) t.Run("Forbidden (other's message)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) + e := env.makeExp(t) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) e.DELETE("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, postmanID)). + WithCookie(session.CookieName, env.generateSession(t, postmanID)). Expect(). Status(http.StatusForbidden) }) t.Run("Forbidden (other's webhook message)", func(t *testing.T) { t.Parallel() - wb := mustMakeWebhook(t, repo, rand, channel.ID, testUser.GetID(), "") - message := mustMakeMessage(t, repo, wb.GetBotUserID(), channel.ID) + wb := env.mustMakeWebhook(t, rand, channel.ID, testUser.GetID(), "") + message := env.mustMakeMessage(t, wb.GetBotUserID(), channel.ID) - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/messages/{messageID}", message.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, postmanID)). + WithCookie(session.CookieName, env.generateSession(t, postmanID)). Expect(). Status(http.StatusForbidden) }) @@ -257,14 +257,14 @@ func TestHandlers_DeleteMessageByID(t *testing.T) { func TestHandlers_PostMessageReport(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) + channel := env.mustMakeChannel(t, rand) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/messages/{messageID}/report", message.ID.String()). WithJSON(map[string]string{"reason": "aaaa"}). Expect(). @@ -273,23 +273,23 @@ func TestHandlers_PostMessageReport(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/messages/{messageID}/report", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"reason": "aaaa"}). Expect(). Status(http.StatusNoContent) - r, err := repo.GetMessageReportsByMessageID(message.ID) + r, err := env.Repository.GetMessageReportsByMessageID(message.ID) assert.NoError(t, err) assert.Len(t, r, 1) }) t.Run("Failure1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/messages/{messageID}/report", message.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"not_reason": "aaaa"}). Expect(). Status(http.StatusBadRequest) @@ -298,15 +298,15 @@ func TestHandlers_PostMessageReport(t *testing.T) { func TestHandlers_DeleteUnread(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common2) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common2) - channel := mustMakeChannel(t, repo, rand) - message := mustMakeMessage(t, repo, testUser.GetID(), channel.ID) - mustMakeMessageUnread(t, repo, testUser.GetID(), message.ID) + channel := env.mustMakeChannel(t, rand) + message := env.mustMakeMessage(t, testUser.GetID(), channel.ID) + env.mustMakeMessageUnread(t, testUser.GetID(), message.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/me/unread/channels/{channelID}", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -314,9 +314,9 @@ func TestHandlers_DeleteUnread(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/me/unread/channels/{channelID}", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) }) diff --git a/router/v1/notification_test.go b/router/v1/notification_test.go index 5d98d565e..34dc8ea52 100644 --- a/router/v1/notification_test.go +++ b/router/v1/notification_test.go @@ -6,23 +6,23 @@ import ( "github.com/stretchr/testify/require" "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "net/http" "testing" ) func TestHandlers_PutNotificationStatus(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common2) + env, _, _, s, _ := setup(t, common2) - user := mustMakeUser(t, repo, rand) + user := env.mustMakeUser(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/notification", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -31,16 +31,16 @@ func TestHandlers_PutNotificationStatus(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/notification", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string][]string{"on": {user.GetID().String()}}). Expect(). Status(http.StatusNoContent) - subscriptions, err := repo.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) + subscriptions, err := env.Repository.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) require.NoError(t, err) users := make([]uuid.UUID, 0) for _, subscription := range subscriptions { @@ -53,16 +53,16 @@ func TestHandlers_PutNotificationStatus(t *testing.T) { t.Run("Successful2", func(t *testing.T) { t.Parallel() - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/notification", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string][]uuid.UUID{"on": {uuid.Must(uuid.NewV4()), user.GetID(), uuid.Must(uuid.NewV4())}, "off": {uuid.Must(uuid.NewV4())}}). Expect(). Status(http.StatusNoContent) - subscriptions, err := repo.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) + subscriptions, err := env.Repository.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) require.NoError(t, err) users := make([]uuid.UUID, 0) for _, subscription := range subscriptions { @@ -75,17 +75,17 @@ func TestHandlers_PutNotificationStatus(t *testing.T) { t.Run("Successful3", func(t *testing.T) { t.Parallel() - channel := mustMakeChannel(t, repo, rand) - mustChangeChannelSubscription(t, repo, channel.ID, user.GetID()) + channel := env.mustMakeChannel(t, rand) + env.mustChangeChannelSubscription(t, channel.ID, user.GetID()) - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/channels/{channelID}/notification", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string][]string{"off": {user.GetID().String()}}). Expect(). Status(http.StatusNoContent) - subscriptions, err := repo.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) + subscriptions, err := env.Repository.GetChannelSubscriptions(repository.ChannelSubscriptionQuery{}.SetChannel(channel.ID).SetLevel(model.ChannelSubscribeLevelMarkAndNotify)) require.NoError(t, err) assert.Len(t, subscriptions, 0) }) @@ -93,16 +93,16 @@ func TestHandlers_PutNotificationStatus(t *testing.T) { func TestHandlers_GetNotificationStatus(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common2) + env, _, _, s, _ := setup(t, common2) - channel := mustMakeChannel(t, repo, rand) - user := mustMakeUser(t, repo, rand) + channel := env.mustMakeChannel(t, rand) + user := env.mustMakeUser(t, rand) - mustChangeChannelSubscription(t, repo, channel.ID, user.GetID()) + env.mustChangeChannelSubscription(t, channel.ID, user.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/notification", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -110,9 +110,9 @@ func TestHandlers_GetNotificationStatus(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/channels/{channelID}/notification", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -124,15 +124,15 @@ func TestHandlers_GetNotificationStatus(t *testing.T) { func TestHandlers_GetNotificationChannels(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common2) + env, _, _, s, _ := setup(t, common2) - user := mustMakeUser(t, repo, rand) - mustChangeChannelSubscription(t, repo, mustMakeChannel(t, repo, rand).ID, user.GetID()) - mustChangeChannelSubscription(t, repo, mustMakeChannel(t, repo, rand).ID, user.GetID()) + user := env.mustMakeUser(t, rand) + env.mustChangeChannelSubscription(t, env.mustMakeChannel(t, rand).ID, user.GetID()) + env.mustChangeChannelSubscription(t, env.mustMakeChannel(t, rand).ID, user.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/notification", user.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -140,9 +140,9 @@ func TestHandlers_GetNotificationChannels(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/notification", user.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -154,14 +154,14 @@ func TestHandlers_GetNotificationChannels(t *testing.T) { func TestHandlers_GetMyNotificationChannels(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common2) + env, _, _, s, _, user, _ := setupWithUsers(t, common2) - mustChangeChannelSubscription(t, repo, mustMakeChannel(t, repo, rand).ID, user.GetID()) - mustChangeChannelSubscription(t, repo, mustMakeChannel(t, repo, rand).ID, user.GetID()) + env.mustChangeChannelSubscription(t, env.mustMakeChannel(t, rand).ID, user.GetID()) + env.mustChangeChannelSubscription(t, env.mustMakeChannel(t, rand).ID, user.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/notification"). Expect(). Status(http.StatusUnauthorized) @@ -169,9 +169,9 @@ func TestHandlers_GetMyNotificationChannels(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/notification"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). diff --git a/router/v1/public_test.go b/router/v1/public_test.go index f675d514b..fcf2e4238 100644 --- a/router/v1/public_test.go +++ b/router/v1/public_test.go @@ -11,11 +11,11 @@ import ( func TestHandlers_GetPublicUserIcon(t *testing.T) { t.Parallel() - repo, server, _, _, _, _, testUser, _ := setupWithUsers(t, common5) + env, _, _, _, _, testUser, _ := setupWithUsers(t, common5) t.Run("No name", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/icon/"). Expect(). Status(http.StatusNotFound) @@ -23,7 +23,7 @@ func TestHandlers_GetPublicUserIcon(t *testing.T) { t.Run("No user", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/icon/no+user"). Expect(). Status(http.StatusNotFound) @@ -33,9 +33,9 @@ func TestHandlers_GetPublicUserIcon(t *testing.T) { t.Parallel() _, require := assertAndRequire(t) - meta, err := repo.GetFileMeta(testUser.GetIconFileID()) + meta, err := env.Repository.GetFileMeta(testUser.GetIconFileID()) require.NoError(err) - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/icon/{username}", testUser.GetName()). Expect(). Status(http.StatusOK). @@ -47,10 +47,10 @@ func TestHandlers_GetPublicUserIcon(t *testing.T) { t.Parallel() _, require := assertAndRequire(t) - meta, err := repo.GetFileMeta(testUser.GetIconFileID()) + meta, err := env.Repository.GetFileMeta(testUser.GetIconFileID()) require.NoError(err) - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/icon/{username}", testUser.GetName()). WithHeader("If-None-Match", strconv.Quote(meta.GetMD5Hash())). Expect(). @@ -60,15 +60,15 @@ func TestHandlers_GetPublicUserIcon(t *testing.T) { func TestHandlers_GetPublicEmojiJSON(t *testing.T) { t.Parallel() - repo, server, _, _, _, _ := setup(t, s3) + env, _, _, _, _ := setup(t, s3) var stamps []interface{} for i := 0; i < 10; i++ { - s := mustMakeStamp(t, repo, rand, uuid.Nil) + s := env.mustMakeStamp(t, rand, uuid.Nil) stamps = append(stamps, s.Name) } - e := makeExp(t, server) + e := env.makeExp(t) res := e.GET("/api/1.0/public/emoji.json"). Expect(). Status(http.StatusOK) @@ -84,7 +84,7 @@ func TestHandlers_GetPublicEmojiJSON(t *testing.T) { t.Run("304", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji.json"). WithHeader(consts.HeaderIfModifiedSince, res.Header(echo.HeaderLastModified).Raw()). Expect(). @@ -93,7 +93,7 @@ func TestHandlers_GetPublicEmojiJSON(t *testing.T) { t.Run("Return cache", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji.json"). Expect(). Status(http.StatusOK) @@ -102,13 +102,13 @@ func TestHandlers_GetPublicEmojiJSON(t *testing.T) { func TestHandlers_GetPublicEmojiCSS(t *testing.T) { t.Parallel() - repo, server, _, _, _, _ := setup(t, s4) + env, _, _, _, _ := setup(t, s4) for i := 0; i < 10; i++ { - mustMakeStamp(t, repo, rand, uuid.Nil) + env.mustMakeStamp(t, rand, uuid.Nil) } - e := makeExp(t, server) + e := env.makeExp(t) res := e.GET("/api/1.0/public/emoji.css"). Expect(). Status(http.StatusOK) @@ -118,7 +118,7 @@ func TestHandlers_GetPublicEmojiCSS(t *testing.T) { t.Run("304", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji.css"). WithHeader(consts.HeaderIfModifiedSince, res.Header(echo.HeaderLastModified).Raw()). Expect(). @@ -127,7 +127,7 @@ func TestHandlers_GetPublicEmojiCSS(t *testing.T) { t.Run("Return cache", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji.css"). Expect(). Status(http.StatusOK) @@ -136,13 +136,13 @@ func TestHandlers_GetPublicEmojiCSS(t *testing.T) { func TestHandlers_GetPublicEmojiImage(t *testing.T) { t.Parallel() - repo, server, _, _, _, _ := setup(t, common5) + env, _, _, _, _ := setup(t, common5) - s := mustMakeStamp(t, repo, rand, uuid.Nil) + s := env.mustMakeStamp(t, rand, uuid.Nil) t.Run("Not Found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji/{stampID}", uuid.Must(uuid.NewV4())). Expect(). Status(http.StatusNotFound) @@ -150,7 +150,7 @@ func TestHandlers_GetPublicEmojiImage(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji/{stampID}", s.ID). Expect(). Status(http.StatusOK) @@ -160,10 +160,10 @@ func TestHandlers_GetPublicEmojiImage(t *testing.T) { t.Parallel() _, require := assertAndRequire(t) - meta, err := repo.GetFileMeta(s.FileID) + meta, err := env.Repository.GetFileMeta(s.FileID) require.NoError(err) - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/public/emoji/{stampID}", s.ID). WithHeader("If-None-Match", strconv.Quote(meta.GetMD5Hash())). Expect(). diff --git a/router/v1/router.go b/router/v1/router.go index 3c0bffcab..69b562b33 100644 --- a/router/v1/router.go +++ b/router/v1/router.go @@ -13,6 +13,7 @@ import ( "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" "github.com/traPtitech/traQ/router/middlewares" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/service/counter" "github.com/traPtitech/traQ/service/heartbeat" imaging2 "github.com/traPtitech/traQ/service/imaging" @@ -47,6 +48,7 @@ type Handlers struct { VM *viewer.Manager HeartBeats *heartbeat.Manager Imaging imaging2.Processor + SessStore session.Store emojiJSONCache bytes.Buffer `wire:"-"` emojiJSONTime time.Time `wire:"-"` @@ -63,7 +65,7 @@ func (h *Handlers) Setup(e *echo.Group) { bodyLimit := middlewares.RequestBodyLengthLimit retrieve := middlewares.NewParamRetriever(h.Repo) blockBot := middlewares.BlockBot(h.Repo) - nologin := middlewares.NoLogin() + nologin := middlewares.NoLogin(h.SessStore) requiresBotAccessPerm := middlewares.CheckBotAccessPerm(h.RBAC, h.Repo) requiresWebhookAccessPerm := middlewares.CheckWebhookAccessPerm(h.RBAC, h.Repo) @@ -74,7 +76,7 @@ func (h *Handlers) Setup(e *echo.Group) { gone := func(c echo.Context) error { return herror.HTTPError(http.StatusGone, "this api has been deleted") } - api := e.Group("/1.0", middlewares.UserAuthenticate(h.Repo)) + api := e.Group("/1.0", middlewares.UserAuthenticate(h.Repo, h.SessStore)) { apiUsers := api.Group("/users") { diff --git a/router/v1/router_test.go b/router/v1/router_test.go index 592d3de60..92a44d5f2 100644 --- a/router/v1/router_test.go +++ b/router/v1/router_test.go @@ -7,9 +7,9 @@ import ( "github.com/leandro-lugaresi/hub" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/service/counter" - imaging2 "github.com/traPtitech/traQ/service/imaging" + "github.com/traPtitech/traQ/service/imaging" "github.com/traPtitech/traQ/service/rbac" "github.com/traPtitech/traQ/service/rbac/permission" "github.com/traPtitech/traQ/service/viewer" @@ -45,10 +45,7 @@ const ( s4 = "s4" ) -var ( - servers = map[string]*httptest.Server{} - repositories = map[string]*TestRepository{} -) +var envs = map[string]*Env{} func TestMain(m *testing.M) { // setup server @@ -65,24 +62,27 @@ func TestMain(m *testing.M) { s4, } for _, key := range repos { - repo := NewTestRepository() + env := &Env{} + env.Repository = NewTestRepository() + env.Hub = hub.New() + env.SessStore = session.NewMemorySessionStore() + env.RBAC = newTestRBAC() e := echo.New() e.HideBanner = true e.HidePort = true e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) - e.Use(extension.Wrap(repo)) + e.Use(extension.Wrap(env.Repository)) - h := hub.New() handlers := &Handlers{ - RBAC: newTestRBAC(), - Repo: repo, - Hub: h, - Logger: zap.NewNop(), - OC: counter.NewOnlineCounter(h), - VM: viewer.NewManager(h), - HeartBeats: nil, - Imaging: imaging2.NewProcessor(imaging2.Config{ + RBAC: env.RBAC, + Repo: env.Repository, + Hub: env.Hub, + Logger: zap.NewNop(), + OC: counter.NewOnlineCounter(env.Hub), + VM: viewer.NewManager(env.Hub), + SessStore: env.SessStore, + Imaging: imaging.NewProcessor(imaging.Config{ MaxPixels: 1000 * 1000, Concurrency: 1, ThumbnailMaxSize: image.Pt(360, 480), @@ -90,69 +90,70 @@ func TestMain(m *testing.M) { }), } handlers.Setup(e.Group("/api")) - servers[key] = httptest.NewServer(e) - repositories[key] = repo + env.Server = httptest.NewServer(e) + envs[key] = env } code := m.Run() - for _, v := range servers { - v.Close() + for _, env := range envs { + env.Server.Close() + env.Hub.Close() } - os.Exit(code) } -func setup(t *testing.T, server string) (repository.Repository, *httptest.Server, *assert.Assertions, *require.Assertions, string, string) { +type Env struct { + Server *httptest.Server + Repository repository.Repository + Hub *hub.Hub + SessStore session.Store + RBAC rbac.RBAC +} + +func setup(t *testing.T, server string) (*Env, *assert.Assertions, *require.Assertions, string, string) { t.Helper() - s, ok := servers[server] + env, ok := envs[server] if !ok { t.FailNow() } assert, require := assertAndRequire(t) - repo := repositories[server] - testUser := mustMakeUser(t, repo, rand) + repo := env.Repository + testUser := env.mustMakeUser(t, rand) adminUser, err := repo.GetUserByName("traq", true) require.NoError(err) - return repo, s, assert, require, generateSession(t, testUser.GetID()), generateSession(t, adminUser.GetID()) + return env, assert, require, env.generateSession(t, testUser.GetID()), env.generateSession(t, adminUser.GetID()) } -func setupWithUsers(t *testing.T, server string) (repository.Repository, *httptest.Server, *assert.Assertions, *require.Assertions, string, string, model.UserInfo, model.UserInfo) { +func setupWithUsers(t *testing.T, server string) (*Env, *assert.Assertions, *require.Assertions, string, string, model.UserInfo, model.UserInfo) { t.Helper() - s, ok := servers[server] + env, ok := envs[server] if !ok { t.FailNow() } assert, require := assertAndRequire(t) - repo := repositories[server] - testUser := mustMakeUser(t, repo, rand) + repo := env.Repository + testUser := env.mustMakeUser(t, rand) adminUser, err := repo.GetUserByName("traq", true) require.NoError(err) - return repo, s, assert, require, generateSession(t, testUser.GetID()), generateSession(t, adminUser.GetID()), testUser, adminUser + return env, assert, require, env.generateSession(t, testUser.GetID()), env.generateSession(t, adminUser.GetID()), testUser, adminUser } func assertAndRequire(t *testing.T) (*assert.Assertions, *require.Assertions) { return assert.New(t), require.New(t) } -func generateSession(t *testing.T, userID uuid.UUID) string { +func (env *Env) generateSession(t *testing.T, userID uuid.UUID) string { t.Helper() - require := require.New(t) - req := httptest.NewRequest(echo.GET, "/", nil) - rec := httptest.NewRecorder() - - sess, err := sessions.Get(rec, req, true) - require.NoError(err) - require.NoError(sess.SetUser(userID)) - cookie := parseCookies(rec.Header().Get("Set-Cookie"))[sessions.CookieName] - - return cookie.Value + sess, err := env.SessStore.IssueSession(userID, nil) + require.NoError(t, err) + return sess.Token() } -func makeExp(t *testing.T, server *httptest.Server) *httpexpect.Expect { +func (env *Env) makeExp(t *testing.T) *httpexpect.Expect { t.Helper() return httpexpect.WithConfig(httpexpect.Config{ - BaseURL: server.URL, + BaseURL: env.Server.URL, Reporter: httpexpect.NewAssertReporter(t), Printers: []httpexpect.Printer{ httpexpect.NewCurlPrinter(t), @@ -168,50 +169,42 @@ func makeExp(t *testing.T, server *httptest.Server) *httpexpect.Expect { }) } -func parseCookies(value string) map[string]*http.Cookie { - m := map[string]*http.Cookie{} - for _, c := range (&http.Request{Header: http.Header{"Cookie": {value}}}).Cookies() { - m[c.Name] = c - } - return m -} - -func mustMakeChannel(t *testing.T, repo repository.Repository, name string) *model.Channel { +func (env *Env) mustMakeChannel(t *testing.T, name string) *model.Channel { t.Helper() if name == rand { name = random.AlphaNumeric(20) } - ch, err := repo.CreatePublicChannel(name, uuid.Nil, uuid.Nil) + ch, err := env.Repository.CreatePublicChannel(name, uuid.Nil, uuid.Nil) require.NoError(t, err) return ch } -func mustMakeMessage(t *testing.T, repo repository.Repository, userID, channelID uuid.UUID) *model.Message { +func (env *Env) mustMakeMessage(t *testing.T, userID, channelID uuid.UUID) *model.Message { t.Helper() - m, err := repo.CreateMessage(userID, channelID, "popopo") + m, err := env.Repository.CreateMessage(userID, channelID, "popopo") require.NoError(t, err) return m } -func mustMakeMessageUnread(t *testing.T, repo repository.Repository, userID, messageID uuid.UUID) { +func (env *Env) mustMakeMessageUnread(t *testing.T, userID, messageID uuid.UUID) { t.Helper() - require.NoError(t, repo.SetMessageUnread(userID, messageID, false)) + require.NoError(t, env.Repository.SetMessageUnread(userID, messageID, false)) } -func mustMakeUser(t *testing.T, repo repository.Repository, userName string) model.UserInfo { +func (env *Env) mustMakeUser(t *testing.T, userName string) model.UserInfo { t.Helper() if userName == rand { userName = random.AlphaNumeric(32) } - u, err := repo.CreateUser(repository.CreateUserArgs{Name: userName, Password: "test", Role: role.User}) + u, err := env.Repository.CreateUser(repository.CreateUserArgs{Name: userName, Password: "test", Role: role.User}) require.NoError(t, err) return u } -func mustMakeFile(t *testing.T, repo repository.Repository) model.FileMeta { +func (env *Env) mustMakeFile(t *testing.T) model.FileMeta { t.Helper() buf := bytes.NewBufferString("test message") - f, err := repo.SaveFile(repository.SaveFileArgs{ + f, err := env.Repository.SaveFile(repository.SaveFileArgs{ FileName: "test.txt", FileSize: int64(buf.Len()), FileType: model.FileTypeUserFile, @@ -221,62 +214,62 @@ func mustMakeFile(t *testing.T, repo repository.Repository) model.FileMeta { return f } -func mustMakeTag(t *testing.T, repo repository.Repository, userID uuid.UUID, tagText string) uuid.UUID { +func (env *Env) mustMakeTag(t *testing.T, userID uuid.UUID, tagText string) uuid.UUID { t.Helper() if tagText == rand { tagText = random.AlphaNumeric(20) } - tag, err := repo.GetOrCreateTag(tagText) + tag, err := env.Repository.GetOrCreateTag(tagText) require.NoError(t, err) - require.NoError(t, repo.AddUserTag(userID, tag.ID)) + require.NoError(t, env.Repository.AddUserTag(userID, tag.ID)) return tag.ID } -func mustStarChannel(t *testing.T, repo repository.Repository, userID, channelID uuid.UUID) { +func (env *Env) mustStarChannel(t *testing.T, userID, channelID uuid.UUID) { t.Helper() - require.NoError(t, repo.AddStar(userID, channelID)) + require.NoError(t, env.Repository.AddStar(userID, channelID)) } -func mustMakeUserGroup(t *testing.T, repo repository.Repository, name string, adminID uuid.UUID) *model.UserGroup { +func (env *Env) mustMakeUserGroup(t *testing.T, name string, adminID uuid.UUID) *model.UserGroup { t.Helper() if name == rand { name = random.AlphaNumeric(20) } - g, err := repo.CreateUserGroup(name, "", "", adminID) + g, err := env.Repository.CreateUserGroup(name, "", "", adminID) require.NoError(t, err) return g } -func mustAddUserToGroup(t *testing.T, repo repository.Repository, userID, groupID uuid.UUID) { +func (env *Env) mustAddUserToGroup(t *testing.T, userID, groupID uuid.UUID) { t.Helper() - require.NoError(t, repo.AddUserToGroup(userID, groupID, "")) + require.NoError(t, env.Repository.AddUserToGroup(userID, groupID, "")) } -func mustMakeWebhook(t *testing.T, repo repository.Repository, name string, channelID, creatorID uuid.UUID, secret string) model.Webhook { +func (env *Env) mustMakeWebhook(t *testing.T, name string, channelID, creatorID uuid.UUID, secret string) model.Webhook { t.Helper() if name == rand { name = random.AlphaNumeric(20) } - w, err := repo.CreateWebhook(name, "", channelID, creatorID, secret) + w, err := env.Repository.CreateWebhook(name, "", channelID, creatorID, secret) require.NoError(t, err) return w } -func mustMakeStamp(t *testing.T, repo repository.Repository, name string, userID uuid.UUID) *model.Stamp { +func (env *Env) mustMakeStamp(t *testing.T, name string, userID uuid.UUID) *model.Stamp { t.Helper() if name == rand { name = random.AlphaNumeric(20) } - fileID, err := repository.GenerateIconFile(repo, name) + fileID, err := repository.GenerateIconFile(env.Repository, name) require.NoError(t, err) - s, err := repo.CreateStamp(repository.CreateStampArgs{Name: name, FileID: fileID, CreatorID: userID}) + s, err := env.Repository.CreateStamp(repository.CreateStampArgs{Name: name, FileID: fileID, CreatorID: userID}) require.NoError(t, err) return s } -func mustChangeChannelSubscription(t *testing.T, repo repository.Repository, channelID, userID uuid.UUID) { +func (env *Env) mustChangeChannelSubscription(t *testing.T, channelID, userID uuid.UUID) { t.Helper() - require.NoError(t, repo.ChangeChannelSubscription(channelID, repository.ChangeChannelSubscriptionArgs{Subscription: map[uuid.UUID]model.ChannelSubscribeLevel{userID: model.ChannelSubscribeLevelMarkAndNotify}})) + require.NoError(t, env.Repository.ChangeChannelSubscription(channelID, repository.ChangeChannelSubscriptionArgs{Subscription: map[uuid.UUID]model.ChannelSubscribeLevel{userID: model.ChannelSubscribeLevelMarkAndNotify}})) } type rbacImpl struct { diff --git a/router/v1/sessions.go b/router/v1/sessions.go index 1373927ae..b3360f916 100644 --- a/router/v1/sessions.go +++ b/router/v1/sessions.go @@ -1,10 +1,10 @@ package v1 import ( + "github.com/gofrs/uuid" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" "net/http" "time" ) @@ -13,28 +13,21 @@ import ( func (h *Handlers) GetMySessions(c echo.Context) error { userID := getRequestUserID(c) - ses, err := sessions.GetByUserID(userID) + ses, err := h.SessStore.GetSessionsByUserID(userID) if err != nil { return herror.InternalServerError(err) } type response struct { - ID string `json:"id"` - LastIP string `json:"lastIP"` - LastUserAgent string `json:"lastUserAgent"` - LastAccess time.Time `json:"lastAccess"` - CreatedAt time.Time `json:"createdAt"` + ID uuid.UUID `json:"id"` + CreatedAt time.Time `json:"createdAt"` } res := make([]response, len(ses)) for k, v := range ses { - referenceID, created, lastAccess, lastIP, lastUserAgent := v.GetSessionInfo() res[k] = response{ - ID: referenceID.String(), - LastIP: lastIP, - LastUserAgent: lastUserAgent, - LastAccess: lastAccess, - CreatedAt: created, + ID: v.RefID(), + CreatedAt: v.CreatedAt(), } } @@ -45,8 +38,7 @@ func (h *Handlers) GetMySessions(c echo.Context) error { func (h *Handlers) DeleteAllMySessions(c echo.Context) error { userID := getRequestUserID(c) - err := sessions.DestroyByUserID(userID) - if err != nil { + if err := h.SessStore.RevokeSessionsByUserID(userID); err != nil { return herror.InternalServerError(err) } @@ -55,11 +47,9 @@ func (h *Handlers) DeleteAllMySessions(c echo.Context) error { // DeleteMySession DELETE /users/me/sessions/:referenceID func (h *Handlers) DeleteMySession(c echo.Context) error { - userID := getRequestUserID(c) - referenceID := getRequestParamAsUUID(c, consts.ParamReferenceID) + refID := getRequestParamAsUUID(c, consts.ParamReferenceID) - err := sessions.DestroyByReferenceID(userID, referenceID) - if err != nil { + if err := h.SessStore.RevokeSessionByRefID(refID); err != nil { return herror.InternalServerError(err) } diff --git a/router/v1/stars_test.go b/router/v1/stars_test.go index c17a0b189..f9c46e654 100644 --- a/router/v1/stars_test.go +++ b/router/v1/stars_test.go @@ -3,21 +3,21 @@ package v1 import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "net/http" "testing" ) func TestHandlers_GetStars(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common3) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common3) - channel := mustMakeChannel(t, repo, rand) - mustStarChannel(t, repo, testUser.GetID(), channel.ID) + channel := env.mustMakeChannel(t, rand) + env.mustStarChannel(t, testUser.GetID(), channel.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/stars"). Expect(). Status(http.StatusUnauthorized) @@ -25,9 +25,9 @@ func TestHandlers_GetStars(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/stars"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -38,13 +38,13 @@ func TestHandlers_GetStars(t *testing.T) { func TestHandlers_PutStars(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common3) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common3) - channel := mustMakeChannel(t, repo, rand) + channel := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/stars/{channelID}", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -52,13 +52,13 @@ func TestHandlers_PutStars(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/stars/{channelID}", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - a, err := repo.GetStaredChannels(testUser.GetID()) + a, err := env.Repository.GetStaredChannels(testUser.GetID()) require.NoError(t, err) assert.Len(t, a, 1) assert.Contains(t, a, channel.ID) @@ -67,14 +67,14 @@ func TestHandlers_PutStars(t *testing.T) { func TestHandlers_DeleteStars(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common3) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common3) - channel := mustMakeChannel(t, repo, rand) - mustStarChannel(t, repo, testUser.GetID(), channel.ID) + channel := env.mustMakeChannel(t, rand) + env.mustStarChannel(t, testUser.GetID(), channel.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/me/stars/{channelID}", channel.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -82,12 +82,12 @@ func TestHandlers_DeleteStars(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/me/stars/{channelID}", channel.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - a, err := repo.GetStaredChannels(testUser.GetID()) + a, err := env.Repository.GetStaredChannels(testUser.GetID()) require.NoError(t, err) assert.Empty(t, a) }) diff --git a/router/v1/tags_test.go b/router/v1/tags_test.go index e6c8fffba..11c5592c9 100644 --- a/router/v1/tags_test.go +++ b/router/v1/tags_test.go @@ -4,7 +4,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" random2 "github.com/traPtitech/traQ/utils/random" "net/http" "testing" @@ -12,11 +12,11 @@ import ( func TestHandlers_PostUserTag(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common3) + env, _, _, s, _, user, _ := setupWithUsers(t, common3) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/users/{userID}/tags", user.GetID().String()). Expect(). Status(http.StatusUnauthorized) @@ -24,15 +24,15 @@ func TestHandlers_PostUserTag(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) tag := random2.AlphaNumeric(20) e.POST("/api/1.0/users/{userID}/tags", user.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"tag": tag}). Expect(). Status(http.StatusCreated) - a, err := repo.GetUserTagsByUserID(user.GetID()) + a, err := env.Repository.GetUserTagsByUserID(user.GetID()) require.NoError(t, err) assert.Len(t, a, 1) }) @@ -40,15 +40,15 @@ func TestHandlers_PostUserTag(t *testing.T) { func TestHandlers_GetUserTags(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common3) + env, _, _, s, _, user, _ := setupWithUsers(t, common3) for i := 0; i < 5; i++ { - mustMakeTag(t, repo, user.GetID(), rand) + env.mustMakeTag(t, user.GetID(), rand) } t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/tags", user.GetID().String()). Expect(). Status(http.StatusUnauthorized) @@ -56,9 +56,9 @@ func TestHandlers_GetUserTags(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/tags", user.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -70,14 +70,14 @@ func TestHandlers_GetUserTags(t *testing.T) { func TestHandlers_PatchUserTag(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common3) + env, _, _, s, _, user, _ := setupWithUsers(t, common3) - other := mustMakeUser(t, repo, rand) - tag := mustMakeTag(t, repo, user.GetID(), rand) + other := env.mustMakeUser(t, rand) + tag := env.mustMakeTag(t, user.GetID(), rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/users/{userID}/tags/{tagID}", user.GetID().String(), tag.String()). WithJSON(map[string]bool{"isLocked": true}). Expect(). @@ -86,23 +86,23 @@ func TestHandlers_PatchUserTag(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/users/{userID}/tags/{tagID}", user.GetID().String(), tag.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]bool{"isLocked": true}). Expect(). Status(http.StatusNoContent) - ut, err := repo.GetUserTag(user.GetID(), tag) + ut, err := env.Repository.GetUserTag(user.GetID(), tag) require.NoError(t, err) assert.True(t, ut.GetIsLocked()) }) t.Run("Failure1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/users/{userID}/tags/{tagID}", user.GetID().String(), tag.String()). - WithCookie(sessions.CookieName, generateSession(t, other.GetID())). + WithCookie(session.CookieName, env.generateSession(t, other.GetID())). WithJSON(map[string]bool{"isLocked": true}). Expect(). Status(http.StatusForbidden) @@ -111,13 +111,13 @@ func TestHandlers_PatchUserTag(t *testing.T) { func TestHandlers_DeleteUserTag(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common3) + env, _, _, s, _, user, _ := setupWithUsers(t, common3) - tag := mustMakeTag(t, repo, user.GetID(), rand) + tag := env.mustMakeTag(t, user.GetID(), rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/{userID}/tags/{tagID}", user.GetID().String(), tag.String()). Expect(). Status(http.StatusUnauthorized) @@ -125,13 +125,13 @@ func TestHandlers_DeleteUserTag(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/users/{userID}/tags/{tagID}", user.GetID().String(), tag.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - _, err := repo.GetUserTag(user.GetID(), tag) + _, err := env.Repository.GetUserTag(user.GetID(), tag) require.Equal(t, repository.ErrNotFound, err) }) } diff --git a/router/v1/user_group_test.go b/router/v1/user_group_test.go index 789107a9b..fc7886432 100644 --- a/router/v1/user_group_test.go +++ b/router/v1/user_group_test.go @@ -4,7 +4,7 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" random2 "github.com/traPtitech/traQ/utils/random" "net/http" "testing" @@ -12,15 +12,15 @@ import ( func TestHandlers_GetUserGroups(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, _, adminUser := setupWithUsers(t, s1) + env, _, _, s, _, _, adminUser := setupWithUsers(t, s1) - mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustMakeUserGroup(t, repo, rand, adminUser.GetID()) + env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustMakeUserGroup(t, rand, adminUser.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups"). Expect(). Status(http.StatusUnauthorized) @@ -28,9 +28,9 @@ func TestHandlers_GetUserGroups(t *testing.T) { t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -42,11 +42,11 @@ func TestHandlers_GetUserGroups(t *testing.T) { func TestHandlers_PostUserGroups(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, user, adminUser := setupWithUsers(t, common5) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) name := random2.AlphaNumeric(20) e.POST("/api/1.0/groups"). WithJSON(map[string]interface{}{"name": name, "description": name}). @@ -56,9 +56,9 @@ func TestHandlers_PostUserGroups(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": true}). Expect(). Status(http.StatusBadRequest) @@ -66,11 +66,11 @@ func TestHandlers_PostUserGroups(t *testing.T) { t.Run("conflict", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) name := random2.AlphaNumeric(20) - mustMakeUserGroup(t, repo, name, adminUser.GetID()) + env.mustMakeUserGroup(t, name, adminUser.GetID()) e.POST("/api/1.0/groups"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": name, "description": name}). Expect(). Status(http.StatusConflict) @@ -78,10 +78,10 @@ func TestHandlers_PostUserGroups(t *testing.T) { t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) name := random2.AlphaNumeric(20) obj := e.POST("/api/1.0/groups"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": name, "description": name}). Expect(). Status(http.StatusCreated). @@ -98,14 +98,14 @@ func TestHandlers_PostUserGroups(t *testing.T) { func TestHandlers_GetUserGroup(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, user, adminUser := setupWithUsers(t, common5) - g := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustAddUserToGroup(t, repo, user.GetID(), g.ID) + g := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustAddUserToGroup(t, user.GetID(), g.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups/{groupID}", g.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -113,18 +113,18 @@ func TestHandlers_GetUserGroup(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups/{groupID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) obj := e.GET("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -140,14 +140,14 @@ func TestHandlers_GetUserGroup(t *testing.T) { func TestHandlers_PatchUserGroup(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, user, adminUser := setupWithUsers(t, common5) - user2 := mustMakeUser(t, repo, rand) - g := mustMakeUserGroup(t, repo, rand, user.GetID()) + user2 := env.mustMakeUser(t, rand) + g := env.mustMakeUserGroup(t, rand, user.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/groups/{groupID}", g.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -155,9 +155,9 @@ func TestHandlers_PatchUserGroup(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/groups/{groupID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"description": "aaa"}). Expect(). Status(http.StatusNotFound) @@ -165,9 +165,9 @@ func TestHandlers_PatchUserGroup(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": true}). Expect(). Status(http.StatusBadRequest) @@ -175,11 +175,11 @@ func TestHandlers_PatchUserGroup(t *testing.T) { t.Run("conflict", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) name := random2.AlphaNumeric(20) - mustMakeUserGroup(t, repo, name, adminUser.GetID()) + env.mustMakeUserGroup(t, name, adminUser.GetID()) e.PATCH("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": name}). Expect(). Status(http.StatusConflict) @@ -187,9 +187,9 @@ func TestHandlers_PatchUserGroup(t *testing.T) { t.Run("forbidden", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, user2.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user2.GetID())). WithJSON(map[string]interface{}{"description": "aaa"}). Expect(). Status(http.StatusForbidden) @@ -197,16 +197,16 @@ func TestHandlers_PatchUserGroup(t *testing.T) { t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) - g := mustMakeUserGroup(t, repo, rand, user.GetID()) + e := env.makeExp(t) + g := env.mustMakeUserGroup(t, rand, user.GetID()) name := random2.AlphaNumeric(20) e.PATCH("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"name": name, "description": "aaa"}). Expect(). Status(http.StatusNoContent) - a, err := repo.GetUserGroup(g.ID) + a, err := env.Repository.GetUserGroup(g.ID) if assert.NoError(t, err) { assert.Equal(t, a.Name, name) assert.Equal(t, a.Description, "aaa") @@ -217,13 +217,13 @@ func TestHandlers_PatchUserGroup(t *testing.T) { func TestHandlers_DeleteUserGroup(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common5) + env, _, _, s, _, user, _ := setupWithUsers(t, common5) - g := mustMakeUserGroup(t, repo, rand, user.GetID()) + g := env.mustMakeUserGroup(t, rand, user.GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}", g.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -231,47 +231,47 @@ func TestHandlers_DeleteUserGroup(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("forbidden", func(t *testing.T) { t.Parallel() - user2 := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + user2 := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, user2.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user2.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("ok", func(t *testing.T) { t.Parallel() - g := mustMakeUserGroup(t, repo, rand, user.GetID()) - e := makeExp(t, server) + g := env.mustMakeUserGroup(t, rand, user.GetID()) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - _, err := repo.GetUserGroup(g.ID) + _, err := env.Repository.GetUserGroup(g.ID) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) } func TestHandlers_GetUserGroupMembers(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, user, adminUser := setupWithUsers(t, common5) - g := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustAddUserToGroup(t, repo, user.GetID(), g.ID) + g := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustAddUserToGroup(t, user.GetID(), g.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups/{groupID}/members", g.ID.String()). Expect(). Status(http.StatusUnauthorized) @@ -279,18 +279,18 @@ func TestHandlers_GetUserGroupMembers(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups/{groupID}/members", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/groups/{groupID}/members", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -301,13 +301,13 @@ func TestHandlers_GetUserGroupMembers(t *testing.T) { func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common5) - g := mustMakeUserGroup(t, repo, rand, user.GetID()) - user2 := mustMakeUser(t, repo, rand) + env, _, _, s, _, user, _ := setupWithUsers(t, common5) + g := env.mustMakeUserGroup(t, rand, user.GetID()) + user2 := env.mustMakeUser(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", g.ID.String()). WithJSON(map[string]interface{}{"userId": user.GetID()}). Expect(). @@ -316,9 +316,9 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"userId": user.GetID()}). Expect(). Status(http.StatusNotFound) @@ -326,9 +326,9 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Run("bad request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"userId": true}). Expect(). Status(http.StatusBadRequest) @@ -336,9 +336,9 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Run("forbidden", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", g.ID.String()). - WithCookie(sessions.CookieName, generateSession(t, user2.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user2.GetID())). WithJSON(map[string]interface{}{"userId": user.GetID()}). Expect(). Status(http.StatusForbidden) @@ -346,9 +346,9 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Run("unknown user", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]uuid.UUID{"userId": uuid.Must(uuid.NewV4())}). Expect(). Status(http.StatusBadRequest) @@ -356,14 +356,14 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/groups/{groupID}/members", g.ID.String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"userId": user.GetID()}). Expect(). Status(http.StatusNoContent) - ids, err := repo.GetUserIDs(repository.UsersQuery{}.GMemberOf(g.ID)) + ids, err := env.Repository.GetUserIDs(repository.UsersQuery{}.GMemberOf(g.ID)) if assert.NoError(t, err) { assert.ElementsMatch(t, ids, []uuid.UUID{user.GetID()}) } @@ -372,14 +372,14 @@ func TestHandlers_PostUserGroupMembers(t *testing.T) { func TestHandlers_DeleteUserGroupMembers(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common5) - g := mustMakeUserGroup(t, repo, rand, user.GetID()) - mustAddUserToGroup(t, repo, user.GetID(), g.ID) - user2 := mustMakeUser(t, repo, rand) + env, _, _, s, _, user, _ := setupWithUsers(t, common5) + g := env.mustMakeUserGroup(t, rand, user.GetID()) + env.mustAddUserToGroup(t, user.GetID(), g.ID) + user2 := env.mustMakeUser(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}/members/{userID}", g.ID.String(), user.GetID().String()). Expect(). Status(http.StatusUnauthorized) @@ -387,40 +387,40 @@ func TestHandlers_DeleteUserGroupMembers(t *testing.T) { t.Run("not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}/members/{userID}", uuid.Must(uuid.NewV4()), user.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("unknown user", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}/members/{userID}", g.ID.String(), uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) }) t.Run("forbidden", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}/members/{userID}", g.ID.String(), user.GetID().String()). - WithCookie(sessions.CookieName, generateSession(t, user2.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user2.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/groups/{groupID}/members/{userID}", g.ID.String(), user.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - ids, err := repo.GetUserIDs(repository.UsersQuery{}.GMemberOf(g.ID)) + ids, err := env.Repository.GetUserIDs(repository.UsersQuery{}.GMemberOf(g.ID)) if assert.NoError(t, err) { assert.Len(t, ids, 0) } @@ -429,16 +429,16 @@ func TestHandlers_DeleteUserGroupMembers(t *testing.T) { func TestHandlers_GetMyBelongingGroup(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, user, adminUser := setupWithUsers(t, common5) - g1 := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - g2 := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustAddUserToGroup(t, repo, user.GetID(), g1.ID) - mustAddUserToGroup(t, repo, user.GetID(), g2.ID) + g1 := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + g2 := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustAddUserToGroup(t, user.GetID(), g1.ID) + env.mustAddUserToGroup(t, user.GetID(), g2.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/groups"). Expect(). Status(http.StatusUnauthorized) @@ -446,9 +446,9 @@ func TestHandlers_GetMyBelongingGroup(t *testing.T) { t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me/groups"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -459,17 +459,17 @@ func TestHandlers_GetMyBelongingGroup(t *testing.T) { func TestHandlers_GetUserBelongingGroup(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, _, adminUser := setupWithUsers(t, common5) + env, _, _, s, _, _, adminUser := setupWithUsers(t, common5) - user := mustMakeUser(t, repo, rand) - g1 := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - g2 := mustMakeUserGroup(t, repo, rand, adminUser.GetID()) - mustAddUserToGroup(t, repo, user.GetID(), g1.ID) - mustAddUserToGroup(t, repo, user.GetID(), g2.ID) + user := env.mustMakeUser(t, rand) + g1 := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + g2 := env.mustMakeUserGroup(t, rand, adminUser.GetID()) + env.mustAddUserToGroup(t, user.GetID(), g1.ID) + env.mustAddUserToGroup(t, user.GetID(), g2.ID) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/groups", user.GetID().String()). Expect(). Status(http.StatusUnauthorized) @@ -477,18 +477,18 @@ func TestHandlers_GetUserBelongingGroup(t *testing.T) { t.Run("unknown user", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/groups", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("ok", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}/groups", user.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). diff --git a/router/v1/users.go b/router/v1/users.go index ae7890c85..6963949e8 100644 --- a/router/v1/users.go +++ b/router/v1/users.go @@ -18,7 +18,6 @@ import ( "github.com/skip2/go-qrcode" "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" "github.com/traPtitech/traQ/service/rbac/role" ) @@ -56,12 +55,7 @@ func (h *Handlers) PostLogin(c echo.Context) error { } h.L(c).Info("an api login attempt succeeded", zap.String("username", req.Name)) - sess, err := sessions.Get(c.Response(), c.Request(), true) - if err != nil { - return herror.InternalServerError(err) - } - - if err := sess.SetUser(user.GetID()); err != nil { + if _, err := h.SessStore.RenewSession(c, user.GetID()); err != nil { return herror.InternalServerError(err) } @@ -73,16 +67,9 @@ func (h *Handlers) PostLogin(c echo.Context) error { // PostLogout POST /logout func (h *Handlers) PostLogout(c echo.Context) error { - sess, err := sessions.Get(c.Response(), c.Request(), false) - if err != nil { + if err := h.SessStore.RevokeSession(c); err != nil { return herror.InternalServerError(err) } - if sess != nil { - if err := sess.Destroy(c.Response(), c.Request()); err != nil { - return herror.InternalServerError(err) - } - } - if redirect := c.QueryParam("redirect"); len(redirect) > 0 { return c.Redirect(http.StatusFound, redirect) } @@ -196,7 +183,7 @@ func (h *Handlers) PutUserPassword(c echo.Context) error { if err := bindAndValidate(c, &req); err != nil { return err } - return utils.ChangeUserPassword(c, h.Repo, getRequestParamAsUUID(c, consts.ParamUserID), req.NewPassword) + return utils.ChangeUserPassword(c, h.Repo, h.SessStore, getRequestParamAsUUID(c, consts.ParamUserID), req.NewPassword) } // GetUserIcon GET /users/:userID/icon @@ -274,7 +261,7 @@ func (h *Handlers) PutPassword(c echo.Context) error { return echo.NewHTTPError(http.StatusUnauthorized, "current password is wrong") } - return utils.ChangeUserPassword(c, h.Repo, user.GetID(), req.NewPassword) + return utils.ChangeUserPassword(c, h.Repo, h.SessStore, user.GetID(), req.NewPassword) } // GetMyQRCode GET /users/me/qr-code diff --git a/router/v1/users_test.go b/router/v1/users_test.go index 0ee1422cb..d8893ec28 100644 --- a/router/v1/users_test.go +++ b/router/v1/users_test.go @@ -4,7 +4,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traPtitech/traQ/repository" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/utils/optional" "strings" "testing" @@ -14,11 +14,11 @@ import ( func TestHandlers_GetUsers(t *testing.T) { t.Parallel() - _, server, _, _, session, _ := setup(t, s2) + env, _, _, s, _ := setup(t, s2) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users"). Expect(). Status(http.StatusUnauthorized) @@ -26,9 +26,9 @@ func TestHandlers_GetUsers(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -40,11 +40,11 @@ func TestHandlers_GetUsers(t *testing.T) { func TestHandlers_GetMe(t *testing.T) { t.Parallel() - _, server, _, _, session, _, testUser, _ := setupWithUsers(t, common4) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common4) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me"). Expect(). Status(http.StatusUnauthorized) @@ -52,9 +52,9 @@ func TestHandlers_GetMe(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/me"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -67,11 +67,11 @@ func TestHandlers_GetMe(t *testing.T) { func TestHandlers_GetUserByID(t *testing.T) { t.Parallel() - _, server, _, _, session, _, testUser, _ := setupWithUsers(t, common4) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common4) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}", testUser.GetID().String()). Expect(). Status(http.StatusUnauthorized) @@ -79,9 +79,9 @@ func TestHandlers_GetUserByID(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/users/{userID}", testUser.GetID().String()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -94,11 +94,11 @@ func TestHandlers_GetUserByID(t *testing.T) { func TestHandlers_PatchMe(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, user, _ := setupWithUsers(t, common4) + env, _, _, s, _, user, _ := setupWithUsers(t, common4) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/users/me"). Expect(). Status(http.StatusUnauthorized) @@ -106,16 +106,16 @@ func TestHandlers_PatchMe(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) newDisplay := "renamed" newTwitter := "test" e.PATCH("/api/1.0/users/me"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"displayName": newDisplay, "twitterId": newTwitter}). Expect(). Status(http.StatusNoContent) - u, err := repo.GetUser(user.GetID(), true) + u, err := env.Repository.GetUser(user.GetID(), true) require.NoError(t, err) assert.Equal(t, newDisplay, u.GetDisplayName()) assert.Equal(t, newTwitter, u.GetTwitterID()) @@ -123,17 +123,17 @@ func TestHandlers_PatchMe(t *testing.T) { t.Run("Successful2", func(t *testing.T) { t.Parallel() - user := mustMakeUser(t, repo, rand) - require.NoError(t, repo.UpdateUser(user.GetID(), repository.UpdateUserArgs{DisplayName: optional.StringFrom("test")})) + user := env.mustMakeUser(t, rand) + require.NoError(t, env.Repository.UpdateUser(user.GetID(), repository.UpdateUserArgs{DisplayName: optional.StringFrom("test")})) - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/users/me"). - WithCookie(sessions.CookieName, generateSession(t, user.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user.GetID())). WithJSON(map[string]string{"displayName": ""}). Expect(). Status(http.StatusNoContent) - u, err := repo.GetUser(user.GetID(), true) + u, err := env.Repository.GetUser(user.GetID(), true) require.NoError(t, err) assert.Equal(t, "", u.GetDisplayName()) }) @@ -141,11 +141,11 @@ func TestHandlers_PatchMe(t *testing.T) { func TestHandlers_PutPassword(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common4) + env, _, _, s, _ := setup(t, common4) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). Expect(). Status(http.StatusUnauthorized) @@ -153,9 +153,9 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("invalid body", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]interface{}{"password": 111, "newPassword": false}). Expect(). Status(http.StatusBadRequest) @@ -163,9 +163,9 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("invalid password1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"password": "test", "newPassword": "a"}). Expect(). Status(http.StatusBadRequest) @@ -173,9 +173,9 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("invalid password2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"password": "test", "newPassword": "アイウエオ"}). Expect(). Status(http.StatusBadRequest) @@ -183,9 +183,9 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("invalid password3", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"password": "test", "newPassword": strings.Repeat("a", 33)}). Expect(). Status(http.StatusBadRequest) @@ -193,9 +193,9 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("wrong password", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"password": "wrong password", "newPassword": strings.Repeat("a", 20)}). Expect(). Status(http.StatusUnauthorized) @@ -203,17 +203,17 @@ func TestHandlers_PutPassword(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - user := mustMakeUser(t, repo, rand) + user := env.mustMakeUser(t, rand) - e := makeExp(t, server) + e := env.makeExp(t) newPassword := strings.Repeat("a", 20) e.PUT("/api/1.0/users/me/password"). - WithCookie(sessions.CookieName, generateSession(t, user.GetID())). + WithCookie(session.CookieName, env.generateSession(t, user.GetID())). WithJSON(map[string]string{"password": "test", "newPassword": newPassword}). Expect(). Status(http.StatusNoContent) - u, err := repo.GetUser(user.GetID(), false) + u, err := env.Repository.GetUser(user.GetID(), false) require.NoError(t, err) assert.NoError(t, u.Authenticate(newPassword)) }) @@ -221,11 +221,11 @@ func TestHandlers_PutPassword(t *testing.T) { func TestHandlers_PutUserPassword(t *testing.T) { t.Parallel() - repo, server, _, _, session, adminSession, user, _ := setupWithUsers(t, common4) + env, _, _, s, adminSession, user, _ := setupWithUsers(t, common4) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -233,9 +233,9 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("Forbidden", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). WithJSON(map[string]string{"newPassword": "aaaaaaaaaaaaa"}). Expect(). Status(http.StatusForbidden) @@ -243,9 +243,9 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("invalid body", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]interface{}{"newPassword": false}). Expect(). Status(http.StatusBadRequest) @@ -253,9 +253,9 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("invalid password1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]string{"newPassword": "a"}). Expect(). Status(http.StatusBadRequest) @@ -263,9 +263,9 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("invalid password2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]string{"newPassword": "アイウエオ"}). Expect(). Status(http.StatusBadRequest) @@ -273,9 +273,9 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("invalid password3", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]string{"newPassword": strings.Repeat("a", 33)}). Expect(). Status(http.StatusBadRequest) @@ -283,17 +283,17 @@ func TestHandlers_PutUserPassword(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - user := mustMakeUser(t, repo, rand) + user := env.mustMakeUser(t, rand) - e := makeExp(t, server) + e := env.makeExp(t) newPass := strings.Repeat("a", 20) e.PUT("/api/1.0/users/{userID}/password", user.GetID()). - WithCookie(sessions.CookieName, adminSession). + WithCookie(session.CookieName, adminSession). WithJSON(map[string]string{"newPassword": newPass}). Expect(). Status(http.StatusNoContent) - u, err := repo.GetUser(user.GetID(), false) + u, err := env.Repository.GetUser(user.GetID(), false) require.NoError(t, err) assert.NoError(t, u.Authenticate(newPass)) }) @@ -301,11 +301,11 @@ func TestHandlers_PutUserPassword(t *testing.T) { func TestHandlers_PostLogin(t *testing.T) { t.Parallel() - _, server, _, _, _, _, user, _ := setupWithUsers(t, common4) + env, _, _, _, _, user, _ := setupWithUsers(t, common4) t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/login"). WithJSON(map[string]string{"name": user.GetName(), "pass": "test"}). Expect(). @@ -314,7 +314,7 @@ func TestHandlers_PostLogin(t *testing.T) { t.Run("wrong password", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/login"). WithJSON(map[string]string{"name": user.GetName(), "pass": "wrong_password"}). Expect(). diff --git a/router/v1/webhooks_test.go b/router/v1/webhooks_test.go index 3407c5f43..fd3c3284e 100644 --- a/router/v1/webhooks_test.go +++ b/router/v1/webhooks_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/utils/hmac" random2 "github.com/traPtitech/traQ/utils/random" "net/http" @@ -16,15 +16,15 @@ import ( func TestHandlers_GetWebhooks(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) for i := 0; i < 10; i++ { - mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "") + env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "") } t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks"). Expect(). Status(http.StatusUnauthorized) @@ -32,9 +32,9 @@ func TestHandlers_GetWebhooks(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks"). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -45,10 +45,10 @@ func TestHandlers_GetWebhooks(t *testing.T) { t.Run("Other user", func(t *testing.T) { t.Parallel() - u := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + u := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.GET("/api/1.0/webhooks"). - WithCookie(sessions.CookieName, generateSession(t, u.GetID())). + WithCookie(session.CookieName, env.generateSession(t, u.GetID())). Expect(). Status(http.StatusOK). JSON(). @@ -59,12 +59,12 @@ func TestHandlers_GetWebhooks(t *testing.T) { func TestHandlers_PostWebhooks(t *testing.T) { t.Parallel() - repo, server, _, _, session, _ := setup(t, common6) - ch := mustMakeChannel(t, repo, rand) + env, _, _, s, _ := setup(t, common6) + ch := env.mustMakeChannel(t, rand) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks"). WithJSON(map[string]string{"name": "test", "description": "test", "channelId": ch.ID.String()}). Expect(). @@ -73,10 +73,10 @@ func TestHandlers_PostWebhooks(t *testing.T) { t.Run("Bad Request (No channel)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks"). WithJSON(map[string]string{"name": "test", "description": "test"}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) @@ -85,10 +85,10 @@ func TestHandlers_PostWebhooks(t *testing.T) { t.Parallel() assert := assert.New(t) name := random2.AlphaNumeric(20) - e := makeExp(t, server) + e := env.makeExp(t) id := e.POST("/api/1.0/webhooks"). WithJSON(map[string]string{"name": name, "description": "test", "channelId": ch.ID.String()}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusCreated). JSON(). @@ -97,7 +97,7 @@ func TestHandlers_PostWebhooks(t *testing.T) { String(). Raw() - wb, err := repo.GetWebhook(uuid.FromStringOrNil(id)) + wb, err := env.Repository.GetWebhook(uuid.FromStringOrNil(id)) if assert.NoError(err) { assert.Equal(name, wb.GetName()) assert.Equal("test", wb.GetDescription()) @@ -109,13 +109,13 @@ func TestHandlers_PostWebhooks(t *testing.T) { func TestHandlers_GetWebhook(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "") + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "") t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}", wb.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -123,28 +123,28 @@ func TestHandlers_GetWebhook(t *testing.T) { t.Run("Not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Other user", func(t *testing.T) { t.Parallel() - u := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + u := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}", wb.GetID()). - WithCookie(sessions.CookieName, generateSession(t, u.GetID())). + WithCookie(session.CookieName, env.generateSession(t, u.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("Success", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) obj := e.GET("/api/1.0/webhooks/{webhookID}", wb.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -160,13 +160,13 @@ func TestHandlers_GetWebhook(t *testing.T) { func TestHandlers_PatchWebhook(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "secret") + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "secret") t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookId}", wb.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -174,41 +174,41 @@ func TestHandlers_PatchWebhook(t *testing.T) { t.Run("Not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookId}", uuid.Must(uuid.NewV4())). WithJSON(map[string]string{"name": strings.Repeat("a", 30)}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Other user", func(t *testing.T) { t.Parallel() - u := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + u := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookID}", wb.GetID()). WithJSON(map[string]string{"name": strings.Repeat("a", 30)}). - WithCookie(sessions.CookieName, generateSession(t, u.GetID())). + WithCookie(session.CookieName, env.generateSession(t, u.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("Bad Request", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithJSON(map[string]string{"name": strings.Repeat("a", 40)}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) t.Run("Bad Request (Channel Not found)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithJSON(map[string]uuid.UUID{"channelId": uuid.Must(uuid.NewV4())}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) @@ -219,15 +219,15 @@ func TestHandlers_PatchWebhook(t *testing.T) { name := random2.AlphaNumeric(20) desc := random2.AlphaNumeric(20) secret := random2.AlphaNumeric(20) - ch := mustMakeChannel(t, repo, rand) - e := makeExp(t, server) + ch := env.mustMakeChannel(t, rand) + e := env.makeExp(t) e.PATCH("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithJSON(map[string]string{"name": name, "description": desc, "channelId": ch.ID.String(), "secret": secret}). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - wb, err := repo.GetWebhook(wb.GetID()) + wb, err := env.Repository.GetWebhook(wb.GetID()) require.NoError(err) assert.Equal(name, wb.GetName()) assert.Equal(desc, wb.GetDescription()) @@ -238,13 +238,13 @@ func TestHandlers_PatchWebhook(t *testing.T) { func TestHandlers_DeleteWebhook(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "secret") + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "secret") t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/webhooks/{webhookId}", wb.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -252,46 +252,46 @@ func TestHandlers_DeleteWebhook(t *testing.T) { t.Run("Not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.DELETE("/api/1.0/webhooks/{webhookId}", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Other user", func(t *testing.T) { t.Parallel() - u := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + u := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.DELETE("/api/1.0/webhooks/{webhookID}", wb.GetID()). - WithCookie(sessions.CookieName, generateSession(t, u.GetID())). + WithCookie(session.CookieName, env.generateSession(t, u.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("Success", func(t *testing.T) { t.Parallel() - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "secret") - e := makeExp(t, server) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "secret") + e := env.makeExp(t) e.DELETE("/api/1.0/webhooks/{webhookId}", wb.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNoContent) - _, err := repo.GetWebhook(wb.GetID()) + _, err := env.Repository.GetWebhook(wb.GetID()) assert.EqualError(t, err, repository.ErrNotFound.Error()) }) } func TestHandlers_PutWebhookIcon(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "secret") + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "secret") t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookId}/icon", wb.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -299,50 +299,50 @@ func TestHandlers_PutWebhookIcon(t *testing.T) { t.Run("Not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookId}/icon", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) t.Run("Other user", func(t *testing.T) { t.Parallel() - u := mustMakeUser(t, repo, rand) - e := makeExp(t, server) + u := env.mustMakeUser(t, rand) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookID}/icon", wb.GetID()). - WithCookie(sessions.CookieName, generateSession(t, u.GetID())). + WithCookie(session.CookieName, env.generateSession(t, u.GetID())). Expect(). Status(http.StatusForbidden) }) t.Run("Bad Request (No file)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookId}/icon", wb.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) t.Run("Bad Request (Not image)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookId}/icon", wb.GetID()). WithMultipart(). WithFileBytes("file", "test.txt", []byte("text file")). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) t.Run("Bad Request (Bad image)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.PUT("/api/1.0/webhooks/{webhookId}/icon", wb.GetID()). WithMultipart(). WithFileBytes("file", "test.png", []byte("text file")). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusBadRequest) }) @@ -350,13 +350,13 @@ func TestHandlers_PutWebhookIcon(t *testing.T) { func TestHandlers_PostWebhook(t *testing.T) { t.Parallel() - repo, server, _, _, _, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "secret") + env, _, _, _, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "secret") t.Run("Not found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", uuid.Must(uuid.NewV4())). Expect(). Status(http.StatusNotFound) @@ -364,7 +364,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("UnsupportedMediaType", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithJSON(map[string]string{"test": "test"}). Expect(). @@ -373,7 +373,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("Bad Request (No Body)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText(""). Expect(). @@ -382,7 +382,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("Bad Request (Missing Signature)", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText("test"). Expect(). @@ -391,7 +391,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("Unauthorized", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText("test"). WithHeader(consts.HeaderSignature, "abcdef"). @@ -402,7 +402,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("Bad Request (Nil Channel)", func(t *testing.T) { t.Parallel() body := "test" - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText(body). WithHeader(consts.HeaderSignature, hex.EncodeToString(hmac.SHA1([]byte(body), wb.GetSecret()))). @@ -414,7 +414,7 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Run("Bad Request (Channel not found)", func(t *testing.T) { t.Parallel() body := "test" - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText(body). WithHeader(consts.HeaderSignature, hex.EncodeToString(hmac.SHA1([]byte(body), wb.GetSecret()))). @@ -427,14 +427,14 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Parallel() assert, require := assertAndRequire(t) body := "test" - e := makeExp(t, server) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText(body). WithHeader(consts.HeaderSignature, hex.EncodeToString(hmac.SHA1([]byte(body), wb.GetSecret()))). Expect(). Status(http.StatusNoContent) - arr, _, err := repo.GetMessages(repository.MessagesQuery{Channel: ch.ID}) + arr, _, err := env.Repository.GetMessages(repository.MessagesQuery{Channel: ch.ID}) require.NoError(err) if assert.Len(arr, 1) { assert.Equal(wb.GetBotUserID(), arr[0].UserID) @@ -446,8 +446,8 @@ func TestHandlers_PostWebhook(t *testing.T) { t.Parallel() assert, require := assertAndRequire(t) body := "test" - ch := mustMakeChannel(t, repo, rand) - e := makeExp(t, server) + ch := env.mustMakeChannel(t, rand) + e := env.makeExp(t) e.POST("/api/1.0/webhooks/{webhookId}", wb.GetID()). WithText(body). WithHeader(consts.HeaderSignature, hex.EncodeToString(hmac.SHA1([]byte(body), wb.GetSecret()))). @@ -455,7 +455,7 @@ func TestHandlers_PostWebhook(t *testing.T) { Expect(). Status(http.StatusNoContent) - arr, _, err := repo.GetMessages(repository.MessagesQuery{Channel: ch.ID}) + arr, _, err := env.Repository.GetMessages(repository.MessagesQuery{Channel: ch.ID}) require.NoError(err) if assert.Len(arr, 1) { assert.Equal(wb.GetBotUserID(), arr[0].UserID) @@ -466,17 +466,17 @@ func TestHandlers_PostWebhook(t *testing.T) { func TestHandlers_GetWebhookMessages(t *testing.T) { t.Parallel() - repo, server, _, _, session, _, testUser, _ := setupWithUsers(t, common6) - ch := mustMakeChannel(t, repo, rand) - wb := mustMakeWebhook(t, repo, rand, ch.ID, testUser.GetID(), "") + env, _, _, s, _, testUser, _ := setupWithUsers(t, common6) + ch := env.mustMakeChannel(t, rand) + wb := env.mustMakeWebhook(t, rand, ch.ID, testUser.GetID(), "") for i := 0; i < 60; i++ { - mustMakeMessage(t, repo, wb.GetBotUserID(), ch.ID) + env.mustMakeMessage(t, wb.GetBotUserID(), ch.ID) } t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}/messages", wb.GetID()). Expect(). Status(http.StatusUnauthorized) @@ -484,9 +484,9 @@ func TestHandlers_GetWebhookMessages(t *testing.T) { t.Run("Successful1", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}/messages", wb.GetID()). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -497,11 +497,11 @@ func TestHandlers_GetWebhookMessages(t *testing.T) { t.Run("Successful2", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}/messages", wb.GetID()). WithQuery("limit", 3). WithQuery("offset", 1). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusOK). JSON(). @@ -512,9 +512,9 @@ func TestHandlers_GetWebhookMessages(t *testing.T) { t.Run("Not Found", func(t *testing.T) { t.Parallel() - e := makeExp(t, server) + e := env.makeExp(t) e.GET("/api/1.0/webhooks/{webhookID}/messages", uuid.Must(uuid.NewV4())). - WithCookie(sessions.CookieName, session). + WithCookie(session.CookieName, s). Expect(). Status(http.StatusNotFound) }) diff --git a/router/v3/public_test.go b/router/v3/public_test.go index d07d4eeb2..da680b2ee 100644 --- a/router/v3/public_test.go +++ b/router/v3/public_test.go @@ -7,9 +7,9 @@ import ( func TestHandlers_GetVersion(t *testing.T) { t.Parallel() - _, server := Setup(t, common) + env := Setup(t, common) - e := R(t, server) + e := env.R(t) obj := e.GET("/api/v3/version"). Expect(). Status(http.StatusOK). diff --git a/router/v3/router.go b/router/v3/router.go index 3be6d456b..79b737b68 100644 --- a/router/v3/router.go +++ b/router/v3/router.go @@ -6,6 +6,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/middlewares" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/service/counter" "github.com/traPtitech/traQ/service/imaging" "github.com/traPtitech/traQ/service/rbac" @@ -17,15 +18,16 @@ import ( ) type Handlers struct { - RBAC rbac.RBAC - Repo repository.Repository - WS *ws.Streamer - Hub *hub.Hub - Logger *zap.Logger - OC *counter.OnlineCounter - VM *viewer.Manager - WebRTC *webrtcv3.Manager - Imaging imaging.Processor + RBAC rbac.RBAC + Repo repository.Repository + WS *ws.Streamer + Hub *hub.Hub + Logger *zap.Logger + OC *counter.OnlineCounter + VM *viewer.Manager + WebRTC *webrtcv3.Manager + Imaging imaging.Processor + SessStore session.Store Config } @@ -47,7 +49,7 @@ func (h *Handlers) Setup(e *echo.Group) { bodyLimit := middlewares.RequestBodyLengthLimit retrieve := middlewares.NewParamRetriever(h.Repo) blockBot := middlewares.BlockBot(h.Repo) - nologin := middlewares.NoLogin() + nologin := middlewares.NoLogin(h.SessStore) requiresBotAccessPerm := middlewares.CheckBotAccessPerm(h.RBAC, h.Repo) requiresWebhookAccessPerm := middlewares.CheckWebhookAccessPerm(h.RBAC, h.Repo) @@ -58,7 +60,7 @@ func (h *Handlers) Setup(e *echo.Group) { requiresGroupAdminPerm := middlewares.CheckUserGroupAdminPerm(h.RBAC, h.Repo) requiresClipFolderAccessPerm := middlewares.CheckClipFolderAccessPerm(h.RBAC, h.Repo) - api := e.Group("/v3", middlewares.UserAuthenticate(h.Repo)) + api := e.Group("/v3", middlewares.UserAuthenticate(h.Repo, h.SessStore)) { apiUsers := api.Group("/users") { diff --git a/router/v3/router_test.go b/router/v3/router_test.go index b35d3589d..64eb2551f 100644 --- a/router/v3/router_test.go +++ b/router/v3/router_test.go @@ -12,9 +12,9 @@ import ( "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/extension" - "github.com/traPtitech/traQ/router/sessions" - imaging2 "github.com/traPtitech/traQ/service/imaging" - rbac2 "github.com/traPtitech/traQ/service/rbac" + "github.com/traPtitech/traQ/router/session" + "github.com/traPtitech/traQ/service/imaging" + "github.com/traPtitech/traQ/service/rbac" "github.com/traPtitech/traQ/service/rbac/role" "github.com/traPtitech/traQ/utils/random" "github.com/traPtitech/traQ/utils/storage" @@ -35,12 +35,7 @@ const ( rand = "random" ) -var ( - servers = map[string]*httptest.Server{} - dbConns = map[string]*gorm.DB{} - repositories = map[string]repository.Repository{} - hubs = map[string]*hub.Hub{} -) +var envs = map[string]*Env{} func TestMain(m *testing.M) { user := getEnvOrDefault("MARIADB_USERNAME", "root") @@ -55,6 +50,8 @@ func TestMain(m *testing.M) { } for _, key := range dbs { + env := &Env{} + // テスト用データベース接続 db, err := gorm.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true", user, pass, host, port, fmt.Sprintf("%s%s", dbPrefix, key))) if err != nil { @@ -64,20 +61,20 @@ func TestMain(m *testing.M) { if err := migration.DropAll(db); err != nil { panic(err) } - dbConns[key] = db - hub := hub.New() - hubs[key] = hub + env.DB = db + env.Hub = hub.New() + env.SessStore = session.NewMemorySessionStore() // テスト用リポジトリ作成 - repo, err := repository.NewGormRepository(db, storage.NewInMemoryFileStorage(), hub, zap.NewNop()) + repo, err := repository.NewGormRepository(db, storage.NewInMemoryFileStorage(), env.Hub, zap.NewNop()) if err != nil { panic(err) } if _, err := repo.Sync(); err != nil { panic(err) } - repositories[key] = repo + env.Repository = repo // テスト用サーバー作成 e := echo.New() @@ -86,17 +83,17 @@ func TestMain(m *testing.M) { e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) e.Use(extension.Wrap(repo)) - r, err := rbac2.New(db) + r, err := rbac.New(db) if err != nil { panic(err) } handlers := &Handlers{ - RBAC: r, - Repo: repo, - WS: nil, - Hub: hub, - Logger: zap.NewNop(), - Imaging: imaging2.NewProcessor(imaging2.Config{ + RBAC: r, + Repo: env.Repository, + Hub: env.Hub, + SessStore: env.SessStore, + Logger: zap.NewNop(), + Imaging: imaging.NewProcessor(imaging.Config{ MaxPixels: 1000 * 1000, Concurrency: 1, ThumbnailMaxSize: image.Pt(360, 480), @@ -108,52 +105,54 @@ func TestMain(m *testing.M) { }, } handlers.Setup(e.Group("/api")) - servers[key] = httptest.NewServer(e) + env.Server = httptest.NewServer(e) + + envs[key] = env } // テスト実行 code := m.Run() // 後始末 - for _, v := range servers { - v.Close() - } - for _, v := range dbConns { - v.Close() - } - for _, v := range hubs { - v.Close() + for _, env := range envs { + env.Server.Close() + env.DB.Close() + env.Hub.Close() } os.Exit(code) } +type Env struct { + Server *httptest.Server + DB *gorm.DB + Repository repository.Repository + Hub *hub.Hub + SessStore session.Store +} + // Setup テストセットアップ -func Setup(t *testing.T, server string) (repository.Repository, *httptest.Server) { +func Setup(t *testing.T, server string) *Env { t.Helper() - s, ok := servers[server] + env, ok := envs[server] if !ok { t.FailNow() } - repo := repositories[server] - return repo, s + return env } // S 指定ユーザーのAPIセッショントークンを発行 -func S(t *testing.T, userID uuid.UUID) string { +func (env *Env) S(t *testing.T, userID uuid.UUID) string { t.Helper() - require := require.New(t) - - sess, err := sessions.IssueNewSession("127.0.0.1", "test") - require.NoError(err) - require.NoError(sess.SetUser(userID)) - return sess.GetToken() + s, err := env.SessStore.IssueSession(userID, nil) + require.NoError(t, err) + return s.Token() } // R リクエストテスターを作成 -func R(t *testing.T, server *httptest.Server) *httpexpect.Expect { +func (env *Env) R(t *testing.T) *httpexpect.Expect { t.Helper() return httpexpect.WithConfig(httpexpect.Config{ - BaseURL: server.URL, + BaseURL: env.Server.URL, Reporter: httpexpect.NewAssertReporter(t), Printers: []httpexpect.Printer{ httpexpect.NewCurlPrinter(t), @@ -170,12 +169,12 @@ func R(t *testing.T, server *httptest.Server) *httpexpect.Expect { } // CreateUser ユーザーを必ず作成します -func CreateUser(t *testing.T, repo repository.Repository, userName string) model.UserInfo { +func (env *Env) CreateUser(t *testing.T, userName string) model.UserInfo { t.Helper() if userName == rand { userName = random.AlphaNumeric(32) } - u, err := repo.CreateUser(repository.CreateUserArgs{Name: userName, Password: "testtesttesttest", Role: role.User}) + u, err := env.Repository.CreateUser(repository.CreateUserArgs{Name: userName, Password: "testtesttesttest", Role: role.User}) require.NoError(t, err) return u } diff --git a/router/v3/sessions.go b/router/v3/sessions.go index a1e2091d2..af37c15a5 100644 --- a/router/v3/sessions.go +++ b/router/v3/sessions.go @@ -8,7 +8,6 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" - "github.com/traPtitech/traQ/router/sessions" "github.com/traPtitech/traQ/utils/validator" "go.uber.org/zap" "net/http" @@ -61,12 +60,7 @@ func (h *Handlers) Login(c echo.Context) error { } h.L(c).Info("an api login attempt succeeded", zap.String("username", req.Name)) - sess, err := sessions.Get(c.Response(), c.Request(), true) - if err != nil { - return herror.InternalServerError(err) - } - - if err := sess.SetUser(user.GetID()); err != nil { + if _, err := h.SessStore.RenewSession(c, user.GetID()); err != nil { return herror.InternalServerError(err) } @@ -78,29 +72,20 @@ func (h *Handlers) Login(c echo.Context) error { // Logout POST /logout func (h *Handlers) Logout(c echo.Context) error { - sess, err := sessions.Get(c.Response(), c.Request(), false) + sess, err := h.SessStore.GetSession(c, false) if err != nil { return herror.InternalServerError(err) } - if sess != nil { - if isTrue(c.QueryParam("all")) { - uid := sess.GetUserID() - if uid == uuid.Nil { - if err := sess.Destroy(c.Response(), c.Request()); err != nil { - return herror.InternalServerError(err) - } - } else { - if err := sessions.DestroyByUserID(uid); err != nil { - return herror.InternalServerError(err) - } - } - } else { - if err := sess.Destroy(c.Response(), c.Request()); err != nil { - return herror.InternalServerError(err) - } + if sess != nil && isTrue(c.QueryParam("all")) && sess.LoggedIn() { + if err := h.SessStore.RevokeSessionsByUserID(sess.UserID()); err != nil { + return herror.InternalServerError(err) } } + if err := h.SessStore.RevokeSession(c); err != nil { + return herror.InternalServerError(err) + } + if redirect := c.QueryParam("redirect"); len(redirect) > 0 { return c.Redirect(http.StatusFound, redirect) } @@ -111,28 +96,21 @@ func (h *Handlers) Logout(c echo.Context) error { func (h *Handlers) GetMySessions(c echo.Context) error { userID := getRequestUserID(c) - ses, err := sessions.GetByUserID(userID) + ses, err := h.SessStore.GetSessionsByUserID(userID) if err != nil { return herror.InternalServerError(err) } type response struct { - ID string `json:"id"` - IP string `json:"ip"` - UA string `json:"ua"` - LastAccess time.Time `json:"lastAccess"` - IssuedAt time.Time `json:"issuedAt"` + ID uuid.UUID `json:"id"` + IssuedAt time.Time `json:"issuedAt"` } res := make([]response, len(ses)) for k, v := range ses { - referenceID, created, lastAccess, lastIP, lastUserAgent := v.GetSessionInfo() res[k] = response{ - ID: referenceID.String(), - IP: lastIP, - UA: lastUserAgent, - LastAccess: lastAccess, - IssuedAt: created, + ID: v.RefID(), + IssuedAt: v.CreatedAt(), } } @@ -141,11 +119,9 @@ func (h *Handlers) GetMySessions(c echo.Context) error { // RevokeMySession DELETE /users/me/sessions/:referenceID func (h *Handlers) RevokeMySession(c echo.Context) error { - userID := getRequestUserID(c) referenceID := getParamAsUUID(c, consts.ParamReferenceID) - err := sessions.DestroyByReferenceID(userID, referenceID) - if err != nil { + if err := h.SessStore.RevokeSessionByRefID(referenceID); err != nil { return herror.InternalServerError(err) } diff --git a/router/v3/users.go b/router/v3/users.go index 15aa6ef5d..a23998d14 100644 --- a/router/v3/users.go +++ b/router/v3/users.go @@ -174,7 +174,7 @@ func (h *Handlers) PutMyPassword(c echo.Context) error { return herror.Unauthorized("password is wrong") } - return utils.ChangeUserPassword(c, h.Repo, user.GetID(), req.NewPassword) + return utils.ChangeUserPassword(c, h.Repo, h.SessStore, user.GetID(), req.NewPassword) } // GetMyQRCode GET /users/me/qr-code @@ -306,7 +306,7 @@ func (h *Handlers) ChangeUserPassword(c echo.Context) error { if err := bindAndValidate(c, &req); err != nil { return err } - return utils.ChangeUserPassword(c, h.Repo, getParamAsUUID(c, consts.ParamUserID), req.NewPassword) + return utils.ChangeUserPassword(c, h.Repo, h.SessStore, getParamAsUUID(c, consts.ParamUserID), req.NewPassword) } // GetUser GET /users/:userID diff --git a/router/v3/users_test.go b/router/v3/users_test.go index 4d897f9e6..e90d8ab26 100644 --- a/router/v3/users_test.go +++ b/router/v3/users_test.go @@ -4,7 +4,7 @@ import ( "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/traPtitech/traQ/router/sessions" + "github.com/traPtitech/traQ/router/session" "net/http" "strings" "testing" @@ -13,12 +13,12 @@ import ( func TestHandlers_PutMyPassword(t *testing.T) { t.Parallel() path := "/api/v3/users/me/password" - repo, server := Setup(t, common) - commonSession := S(t, CreateUser(t, repo, rand).GetID()) + env := Setup(t, common) + commonSession := env.S(t, env.CreateUser(t, rand).GetID()) t.Run("NotLoggedIn", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). Expect(). Status(http.StatusUnauthorized) @@ -26,9 +26,9 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("invalid body", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). - WithCookie(sessions.CookieName, commonSession). + WithCookie(session.CookieName, commonSession). WithJSON(echo.Map{"password": 111, "newPassword": false}). Expect(). Status(http.StatusBadRequest) @@ -36,9 +36,9 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("invalid password1", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). - WithCookie(sessions.CookieName, commonSession). + WithCookie(session.CookieName, commonSession). WithJSON(echo.Map{"password": "test", "newPassword": "a"}). Expect(). Status(http.StatusBadRequest) @@ -46,9 +46,9 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("invalid password2", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). - WithCookie(sessions.CookieName, commonSession). + WithCookie(session.CookieName, commonSession). WithJSON(echo.Map{"password": "test", "newPassword": "アイウエオ"}). Expect(). Status(http.StatusBadRequest) @@ -56,9 +56,9 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("invalid password3", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). - WithCookie(sessions.CookieName, commonSession). + WithCookie(session.CookieName, commonSession). WithJSON(echo.Map{"password": "test", "newPassword": strings.Repeat("a", 33)}). Expect(). Status(http.StatusBadRequest) @@ -66,9 +66,9 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("wrong password", func(t *testing.T) { t.Parallel() - e := R(t, server) + e := env.R(t) e.PUT(path). - WithCookie(sessions.CookieName, commonSession). + WithCookie(session.CookieName, commonSession). WithJSON(echo.Map{"password": "wrong password", "newPassword": strings.Repeat("a", 20)}). Expect(). Status(http.StatusUnauthorized) @@ -76,17 +76,17 @@ func TestHandlers_PutMyPassword(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - user := CreateUser(t, repo, rand) + user := env.CreateUser(t, rand) - e := R(t, server) + e := env.R(t) new := strings.Repeat("a", 20) e.PUT(path). - WithCookie(sessions.CookieName, S(t, user.GetID())). + WithCookie(session.CookieName, env.S(t, user.GetID())). WithJSON(echo.Map{"password": "testtesttesttest", "newPassword": new}). Expect(). Status(http.StatusNoContent) - u, err := repo.GetUser(user.GetID(), false) + u, err := env.Repository.GetUser(user.GetID(), false) require.NoError(t, err) assert.NoError(t, u.Authenticate(new)) }) diff --git a/router/wire_gen.go b/router/wire_gen.go index 66e69ed19..07a973190 100644 --- a/router/wire_gen.go +++ b/router/wire_gen.go @@ -6,9 +6,11 @@ package router import ( + "github.com/jinzhu/gorm" "github.com/leandro-lugaresi/hub" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/oauth2" + "github.com/traPtitech/traQ/router/session" "github.com/traPtitech/traQ/router/v1" "github.com/traPtitech/traQ/router/v3" "github.com/traPtitech/traQ/service" @@ -17,8 +19,9 @@ import ( // Injectors from router_wire.go: -func newRouter(hub2 *hub.Hub, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *Router { +func newRouter(hub2 *hub.Hub, db *gorm.DB, repo repository.Repository, ss *service.Services, logger *zap.Logger, config *Config) *Router { echo := newEcho(logger, config, repo) + sessionStore := session.NewGormStore(db) rbac := ss.RBAC streamer := ss.SSE onlineCounter := ss.OnlineCounter @@ -35,34 +38,38 @@ func newRouter(hub2 *hub.Hub, repo repository.Repository, ss *service.Services, VM: manager, HeartBeats: heartbeatManager, Imaging: processor, + SessStore: sessionStore, } wsStreamer := ss.WS webrtcv3Manager := ss.WebRTCv3 v3Config := provideV3Config(config) v3Handlers := &v3.Handlers{ - RBAC: rbac, - Repo: repo, - WS: wsStreamer, - Hub: hub2, - Logger: logger, - OC: onlineCounter, - VM: manager, - WebRTC: webrtcv3Manager, - Imaging: processor, - Config: v3Config, + RBAC: rbac, + Repo: repo, + WS: wsStreamer, + Hub: hub2, + Logger: logger, + OC: onlineCounter, + VM: manager, + WebRTC: webrtcv3Manager, + Imaging: processor, + SessStore: sessionStore, + Config: v3Config, } oauth2Config := provideOAuth2Config(config) handler := &oauth2.Handler{ - RBAC: rbac, - Repo: repo, - Logger: logger, - Config: oauth2Config, + RBAC: rbac, + Repo: repo, + Logger: logger, + SessStore: sessionStore, + Config: oauth2Config, } router := &Router{ - e: echo, - v1: handlers, - v3: v3Handlers, - oauth2: handler, + e: echo, + sessStore: sessionStore, + v1: handlers, + v3: v3Handlers, + oauth2: handler, } return router } diff --git a/service/bot/processor.go b/service/bot/processor.go index 948cbc6c0..2b3bdaed6 100644 --- a/service/bot/processor.go +++ b/service/bot/processor.go @@ -13,7 +13,6 @@ import ( "github.com/traPtitech/traQ/service/bot/event" "go.uber.org/zap" "net/http" - "strconv" "sync" "time" ) @@ -29,7 +28,7 @@ const ( var eventSendCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "traq", Name: "bot_event_send_count_total", -}, []string{"bot_id", "code"}) +}, []string{"bot_id", "status"}) // Processor ボットプロセッサー type Processor struct { @@ -82,7 +81,7 @@ func (p *Processor) sendEvent(b *model.Bot, event event.Type, body []byte) (ok b stop := time.Now() if err != nil { - eventSendCounter.WithLabelValues(b.ID.String(), "-1").Inc() + eventSendCounter.WithLabelValues(b.ID.String(), "ne").Inc() if err := p.repo.WriteBotEventLog(&model.BotEventLog{ RequestID: reqID, BotID: b.ID, @@ -99,7 +98,12 @@ func (p *Processor) sendEvent(b *model.Bot, event event.Type, body []byte) (ok b } _ = res.Body.Close() - eventSendCounter.WithLabelValues(b.ID.String(), strconv.Itoa(res.StatusCode)).Inc() + if res.StatusCode == http.StatusNoContent { + eventSendCounter.WithLabelValues(b.ID.String(), "ok").Inc() + } else { + eventSendCounter.WithLabelValues(b.ID.String(), "ng").Inc() + } + if err := p.repo.WriteBotEventLog(&model.BotEventLog{ RequestID: reqID, BotID: b.ID,