diff --git a/Makefile b/Makefile index 1d2b21ab9..f3b9c68eb 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,8 @@ traQ: $(SOURCES) ## Build traQ binary .PHONY: init init: ## Download and install go mod dependencies go mod download - go install github.com/google/wire/cmd/wire - go install github.com/golang/mock/mockgen + go install github.com/google/wire/cmd/wire@latest + go install github.com/golang/mock/mockgen@latest .PHONY: genkey genkey: ## Generate dev keys diff --git a/cmd/file.go b/cmd/file.go index 62ae5d49d..ccb41fb92 100644 --- a/cmd/file.go +++ b/cmd/file.go @@ -69,7 +69,7 @@ func filePruneCommand() *cobra.Command { } // Repository - repo, err := gorm.NewGormRepository(db, hub.New(), logger) + repo, _, err := gorm.NewGormRepository(db, hub.New(), logger, false) if err != nil { logger.Fatal("failed to initialize repository", zap.Error(err)) } @@ -368,7 +368,7 @@ func genGroupImages() *cobra.Command { } // Repository - repo, err := gorm.NewGormRepository(db, hub.New(), logger) + repo, _, err := gorm.NewGormRepository(db, hub.New(), logger, false) if err != nil { logger.Fatal("failed to initialize repository", zap.Error(err)) } diff --git a/cmd/serve.go b/cmd/serve.go index 2dd8bec6d..e400cbe4a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -75,20 +75,12 @@ func serveCommand() *cobra.Command { // Repository logger.Info("setting up repository...") - repo, err := gorm.NewGormRepository(engine, hub, logger) + repo, init, err := gorm.NewGormRepository(engine, hub, logger, true) if err != nil { logger.Fatal("failed to initialize repository", zap.Error(err)) } logger.Info("repository was set up") - // Repository Sync - logger.Info("syncing repository...") - init, err := repo.Sync() - if err != nil { - logger.Fatal("failed to sync repository", zap.Error(err)) - } - logger.Info("repository was synced") - // JWT for QRCode if priv := c.JWT.Keys.Private; priv != "" { privRaw, err := ioutil.ReadFile(priv) diff --git a/cmd/stamp.go b/cmd/stamp.go index 5b838e5a2..b32d9ac50 100644 --- a/cmd/stamp.go +++ b/cmd/stamp.go @@ -58,7 +58,7 @@ func stampInstallEmojisCommand() *cobra.Command { } // Repository - repo, err := gorm.NewGormRepository(db, hub.New(), logger) + repo, _, err := gorm.NewGormRepository(db, hub.New(), logger, false) if err != nil { logger.Fatal("failed to initialize repository", zap.Error(err)) } diff --git a/repository/gorm/repository.go b/repository/gorm/repository.go index 02a7549a3..33f176fb2 100644 --- a/repository/gorm/repository.go +++ b/repository/gorm/repository.go @@ -6,7 +6,6 @@ import ( "gorm.io/gorm" "github.com/traPtitech/traQ/migration" - "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" ) @@ -18,28 +17,19 @@ type Repository struct { stamps *stampRepository } -// Sync implements Repository interface. -func (repo *Repository) Sync() (init bool, err error) { - if init, err = migration.Migrate(repo.db); err != nil { - return false, err - } - - // スタンプをキャッシュ - var stamps []*model.Stamp - if err := repo.db.Find(&stamps).Error; err != nil { - return false, err - } - repo.stamps = makeStampRepository(stamps) - - return -} - -// NewGormRepository リポジトリ実装を初期化して生成します -func NewGormRepository(db *gorm.DB, hub *hub.Hub, logger *zap.Logger) (repository.Repository, error) { - repo := &Repository{ +// NewGormRepository リポジトリ実装を初期化して生成します。 +// スキーマが初期化された場合、init: true を返します。 +func NewGormRepository(db *gorm.DB, hub *hub.Hub, logger *zap.Logger, doMigration bool) (repo repository.Repository, init bool, err error) { + repo = &Repository{ db: db, hub: hub, logger: logger.Named("repository"), + stamps: makeStampRepository(db), } - return repo, nil + if doMigration { + if init, err = migration.Migrate(db); err != nil { + return nil, false, err + } + } + return } diff --git a/repository/gorm/repository_test.go b/repository/gorm/repository_test.go index b6697421b..fad596a60 100644 --- a/repository/gorm/repository_test.go +++ b/repository/gorm/repository_test.go @@ -78,13 +78,10 @@ func TestMain(m *testing.M) { panic(err) } - repo, err := NewGormRepository(engine, hub.New(), zap.NewNop()) + repo, _, err := NewGormRepository(engine, hub.New(), zap.NewNop(), true) if err != nil { panic(err) } - if _, err := repo.Sync(); err != nil { - panic(err) - } repositories[key] = repo.(*Repository) } diff --git a/repository/gorm/stamp.go b/repository/gorm/stamp.go index 312b905a2..eed891276 100644 --- a/repository/gorm/stamp.go +++ b/repository/gorm/stamp.go @@ -1,13 +1,15 @@ package gorm import ( - "sync" + "context" + "errors" + "sort" "time" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" - jsoniter "github.com/json-iterator/go" "github.com/leandro-lugaresi/hub" + "github.com/motoki317/sc" "gorm.io/gorm" "github.com/traPtitech/traQ/event" @@ -18,93 +20,92 @@ import ( ) type stampRepository struct { - stamps map[uuid.UUID]*model.Stamp - allJSON []byte - unicodeJSON []byte - originalJSON []byte - updatedAt time.Time - sync.RWMutex + stamps *sc.Cache[struct{}, map[uuid.UUID]*model.Stamp] + perType *sc.Cache[repository.StampType, []*model.Stamp] } -func makeStampRepository(stamps []*model.Stamp) *stampRepository { - r := &stampRepository{ - stamps: make(map[uuid.UUID]*model.Stamp, len(stamps)), - updatedAt: time.Now(), - } - for _, s := range stamps { - r.stamps[s.ID] = s - } - - r.regenerateJSON() +func makeStampRepository(db *gorm.DB) *stampRepository { + // Lazy load + r := &stampRepository{} + r.stamps = sc.NewMust(r.loadFunc(db), 365*24*time.Hour, 365*24*time.Hour) + r.perType = sc.NewMust(r.filterFunc(), 365*24*time.Hour, 365*24*time.Hour) return r } -func (r *stampRepository) add(s *model.Stamp) { - r.stamps[s.ID] = s - r.updatedAt = time.Now() - r.regenerateJSON() -} - -func (r *stampRepository) update(s *model.Stamp) { - r.stamps[s.ID] = s - r.updatedAt = time.Now() - r.regenerateJSON() +func (r *stampRepository) loadFunc(db *gorm.DB) func(context.Context, struct{}) (map[uuid.UUID]*model.Stamp, error) { + return func(_ context.Context, _ struct{}) (map[uuid.UUID]*model.Stamp, error) { + var stamps []*model.Stamp + if err := db.Find(&stamps).Error; err != nil { + return nil, err + } + stampsMap := make(map[uuid.UUID]*model.Stamp, len(stamps)) + for _, s := range stamps { + stampsMap[s.ID] = s + } + return stampsMap, nil + } } -func (r *stampRepository) delete(id uuid.UUID) { - delete(r.stamps, id) - r.updatedAt = time.Now() - r.regenerateJSON() -} +func (r *stampRepository) filterFunc() func(_ context.Context, stampType repository.StampType) ([]*model.Stamp, error) { + return func(ctx context.Context, stampType repository.StampType) ([]*model.Stamp, error) { + stamps, err := r.stamps.Get(ctx, struct{}{}) + if err != nil { + return nil, err + } + arr := make([]*model.Stamp, 0, len(stamps)) -func (r *stampRepository) regenerateJSON() { - arrOriginal := make([]*model.Stamp, 0, len(r.stamps)) - arrUnicode := make([]*model.Stamp, 0, len(r.stamps)) - arrAll := make([]*model.Stamp, 0, len(r.stamps)) - for _, stamp := range r.stamps { - arrAll = append(arrAll, stamp) - if stamp.IsUnicode { - arrUnicode = append(arrUnicode, stamp) - } else { - arrOriginal = append(arrOriginal, stamp) + switch stampType { + case repository.StampTypeAll: + for _, s := range stamps { + arr = append(arr, s) + } + case repository.StampTypeUnicode: + for _, s := range stamps { + if s.IsUnicode { + arr = append(arr, s) + } + } + case repository.StampTypeOriginal: + for _, s := range stamps { + if !s.IsUnicode { + arr = append(arr, s) + } + } + default: + return nil, errors.New("unknown stamp type") } - } - b, err := jsoniter.ConfigFastest.Marshal(arrUnicode) - if err != nil { - panic(err) + sort.Slice(arr, func(i, j int) bool { return arr[i].ID.String() < arr[j].ID.String() }) + return arr, nil } - r.unicodeJSON = b +} - b, err = jsoniter.ConfigFastest.Marshal(arrOriginal) - if err != nil { - panic(err) - } - r.originalJSON = b +// Purge purges stamp cache. +func (r *stampRepository) Purge() { + r.stamps.Purge() + r.perType.Purge() +} - b, err = jsoniter.ConfigFastest.Marshal(arrAll) +func (r *stampRepository) GetStamp(id uuid.UUID) (s *model.Stamp, ok bool, err error) { + stamps, err := r.stamps.Get(context.Background(), struct{}{}) if err != nil { - panic(err) + return nil, false, err } - r.allJSON = b -} - -func (r *stampRepository) GetStamp(id uuid.UUID) (s *model.Stamp, ok bool) { - r.RLock() - defer r.RUnlock() - s, ok = r.stamps[id] + s, ok = stamps[id] return } -func (r *stampRepository) CheckIDs(ids []uuid.UUID) bool { - r.RLock() - defer r.RUnlock() +func (r *stampRepository) CheckIDs(ids []uuid.UUID) (ok bool, err error) { + stamps, err := r.stamps.Get(context.Background(), struct{}{}) + if err != nil { + return false, err + } for _, id := range ids { - if _, ok := r.stamps[id]; !ok { - return false + if _, ok := stamps[id]; !ok { + return false, nil } } - return true + return true, nil } // CreateStamp implements StampRepository interface. @@ -117,11 +118,6 @@ func (repo *Repository) CreateStamp(args repository.CreateStampArgs) (s *model.S IsUnicode: args.IsUnicode, } - if repo.stamps != nil { - repo.stamps.Lock() - defer repo.stamps.Unlock() - } - err = repo.db.Transaction(func(tx *gorm.DB) error { // 名前チェック if err := vd.Validate(stamp.Name, validator.StampNameRuleRequired...); err != nil { @@ -149,9 +145,7 @@ func (repo *Repository) CreateStamp(args repository.CreateStampArgs) (s *model.S return nil, err } - if repo.stamps != nil { - repo.stamps.add(stamp) - } + repo.stamps.Purge() repo.hub.Publish(hub.Message{ Name: event.StampCreated, @@ -169,11 +163,6 @@ func (repo *Repository) UpdateStamp(id uuid.UUID, args repository.UpdateStampArg return repository.ErrNilID } - if repo.stamps != nil { - repo.stamps.Lock() - defer repo.stamps.Unlock() - } - var s model.Stamp changes := map[string]interface{}{} err := repo.db.Transaction(func(tx *gorm.DB) error { @@ -220,9 +209,7 @@ func (repo *Repository) UpdateStamp(id uuid.UUID, args repository.UpdateStampArg return err } if len(changes) > 0 { - if repo.stamps != nil { - repo.stamps.update(&s) - } + repo.stamps.Purge() repo.hub.Publish(hub.Message{ Name: event.StampUpdated, Fields: hub.Fields{ @@ -239,18 +226,14 @@ func (repo *Repository) GetStamp(id uuid.UUID) (s *model.Stamp, err error) { return nil, repository.ErrNotFound } - if repo.stamps != nil { - if s, ok := repo.stamps.GetStamp(id); ok { - return s, nil - } - return nil, repository.ErrNotFound + s, ok, err := repo.stamps.GetStamp(id) + if err != nil { + return nil, err } - - s = &model.Stamp{} - if err := repo.db.First(s, &model.Stamp{ID: id}).Error; err != nil { - return nil, convertError(err) + if ok { + return s, nil } - return s, nil + return nil, repository.ErrNotFound } // GetStampByName implements StampRepository interface. @@ -271,19 +254,12 @@ func (repo *Repository) DeleteStamp(id uuid.UUID) (err error) { return repository.ErrNilID } - if repo.stamps != nil { - repo.stamps.Lock() - defer repo.stamps.Unlock() - } - result := repo.db.Delete(&model.Stamp{ID: id}) if result.Error != nil { return result.Error } if result.RowsAffected > 0 { - if repo.stamps != nil { - repo.stamps.delete(id) - } + repo.stamps.Purge() repo.hub.Publish(hub.Message{ Name: event.StampDeleted, Fields: hub.Fields{ @@ -297,38 +273,7 @@ func (repo *Repository) DeleteStamp(id uuid.UUID) (err error) { // GetAllStamps implements StampRepository interface. func (repo *Repository) GetAllStamps(stampType repository.StampType) (stamps []*model.Stamp, err error) { - stamps = make([]*model.Stamp, 0) - tx := repo.db - switch stampType { - case repository.StampTypeUnicode: - tx = tx.Where("is_unicode = TRUE") - case repository.StampTypeOriginal: - tx = tx.Where("is_unicode = FALSE") - } - return stamps, tx.Find(&stamps).Error -} - -// GetStampsJSON implements StampRepository interface. -func (repo *Repository) GetStampsJSON(stampType repository.StampType) ([]byte, time.Time, error) { - if repo.stamps != nil { - repo.stamps.RLock() - defer repo.stamps.RUnlock() - switch stampType { - case repository.StampTypeUnicode: - return repo.stamps.unicodeJSON, repo.stamps.updatedAt, nil - case repository.StampTypeOriginal: - return repo.stamps.originalJSON, repo.stamps.updatedAt, nil - default: - return repo.stamps.allJSON, repo.stamps.updatedAt, nil - } - } - - stamps, err := repo.GetAllStamps(stampType) - if err != nil { - return nil, time.Time{}, err - } - b, err := jsoniter.ConfigFastest.Marshal(stamps) - return b, time.Now(), err + return repo.stamps.perType.Get(context.Background(), stampType) } // StampExists implements StampRepository interface. @@ -337,32 +282,23 @@ func (repo *Repository) StampExists(id uuid.UUID) (bool, error) { return false, nil } - if repo.stamps != nil { - _, ok := repo.stamps.GetStamp(id) - return ok, nil + _, ok, err := repo.stamps.GetStamp(id) + if err != nil { + return false, err } - return gormutil.RecordExists(repo.db, &model.Stamp{ID: id}) + return ok, nil } // ExistStamps implements StampPaletteRepository interface. func (repo *Repository) ExistStamps(stampIDs []uuid.UUID) (err error) { - if repo.stamps != nil { - if repo.stamps.CheckIDs(stampIDs) { - return nil - } - return repository.ArgError("stamp", "stamp is not found") - } - - num, err := gormutil.Count(repo.db. - Table("stamps"). - Where("id IN (?)", stampIDs)) + ok, err := repo.stamps.CheckIDs(stampIDs) if err != nil { return err } - if len(stampIDs) != int(num) { - err = repository.ArgError("stamp", "stamp is not found") + if ok { + return nil } - return + return repository.ArgError("stamp", "stamp is not found") } // GetUserStampHistory implements StampRepository interface. diff --git a/repository/repository.go b/repository/repository.go index ecc416edf..10c4b8688 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -2,11 +2,6 @@ package repository // Repository データリポジトリ type Repository interface { - // Sync DBなどとデータを同期します - // - // スキーマが初期化された場合、trueを返します。 - // DBによるエラーを返すことがあります。 - Sync() (bool, error) UserRepository UserGroupRepository UserSettingsRepository diff --git a/repository/stamp.go b/repository/stamp.go index 7b14d9f6a..cb44a276f 100644 --- a/repository/stamp.go +++ b/repository/stamp.go @@ -87,14 +87,9 @@ type StampRepository interface { DeleteStamp(id uuid.UUID) (err error) // GetAllStamps 全てのスタンプを取得します // - // 成功した場合、スタンプの配列とnilを返します。 + // 成功した場合、スタンプのIDでソートされた配列とnilを返します。 // DBによるエラーを返すことがあります。 GetAllStamps(stampType StampType) (stamps []*model.Stamp, err error) - // GetStampsJSON スタンプ一覧のJSON文字列を取得します - // - // 成功した場合、JSONの[]byte表現とnilを返します。 - // DBによるエラーを返すことがあります。 - GetStampsJSON(stampType StampType) ([]byte, time.Time, error) // StampExists 指定したIDのスタンプが存在するかどうかを返します // // 存在する場合、trueとnilを返します。 diff --git a/router/extension/precond.go b/router/extension/precond.go index 7f20df67a..3ded14f9c 100644 --- a/router/extension/precond.go +++ b/router/extension/precond.go @@ -188,7 +188,7 @@ func CheckPreconditions(c echo.Context, modtime time.Time) (done bool, err error return false, nil } -// ServeJSONWithETag Etagを付与してJSONを返します、304を返せるときは304を返します +// ServeJSONWithETag Etagを付与してJSONを返します。304を返せるときは304を返します。 func ServeJSONWithETag(c echo.Context, i interface{}) error { j := jsoniter.Config{ EscapeHTML: false, @@ -209,12 +209,17 @@ func ServeJSONWithETag(c echo.Context, i interface{}) error { return err } - md5Res := md5.Sum(b) + return ServeWithETag(c, echo.MIMEApplicationJSONCharsetUTF8, b) +} + +// ServeWithETag Etagを付与して返します。304を返せるときは304を返します。 +func ServeWithETag(c echo.Context, contentType string, bytes []byte) error { + md5Res := md5.Sum(bytes) etag := hex.EncodeToString(md5Res[:]) c.Response().Header().Set(consts.HeaderETag, "\""+etag+"\"") if done, err := CheckPreconditions(c, time.Time{}); done { return err } - return c.JSONBlob(http.StatusOK, b) + return c.Blob(http.StatusOK, contentType, bytes) } diff --git a/router/oauth2/oauth2_test.go b/router/oauth2/oauth2_test.go index 6c5ef07ba..e6a6a7f9e 100644 --- a/router/oauth2/oauth2_test.go +++ b/router/oauth2/oauth2_test.go @@ -82,13 +82,10 @@ func TestMain(m *testing.M) { env.SessStore = session.NewMemorySessionStore() // テスト用リポジトリ作成 - repo, err := gorm2.NewGormRepository(engine, env.Hub, zap.NewNop()) + repo, _, err := gorm2.NewGormRepository(engine, env.Hub, zap.NewNop(), true) if err != nil { panic(err) } - if _, err := repo.Sync(); err != nil { - panic(err) - } env.Repository = repo // テスト用サーバー作成 diff --git a/router/router_wire.go b/router/router_wire.go index d52dc6c47..2c6cb16b2 100644 --- a/router/router_wire.go +++ b/router/router_wire.go @@ -25,6 +25,7 @@ func newRouter(hub *hub.Hub, db *gorm.DB, repo repository.Repository, ss *servic newEcho, utils.NewReplaceMapper, message.NewReplacer, + v1.NewEmojiCache, provideOAuth2Config, provideV3Config, session.NewGormStore, diff --git a/router/v1/public.go b/router/v1/public.go index ecf48e812..e07131a25 100644 --- a/router/v1/public.go +++ b/router/v1/public.go @@ -2,12 +2,14 @@ package v1 import ( "bytes" + "context" "fmt" "net/http" "strconv" "time" "github.com/labstack/echo/v4" + "github.com/motoki317/sc" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" @@ -58,98 +60,20 @@ func (h *Handlers) GetPublicUserIcon(c echo.Context) error { // GetPublicEmojiJSON GET /public/emoji.json func (h *Handlers) GetPublicEmojiJSON(c echo.Context) error { - extension.SetLastModified(c, h.emojiJSONTime) - if done, _ := extension.CheckPreconditions(c, h.emojiJSONTime); done { - return nil - } - - // キャッシュ確認 - h.emojiJSONCacheLock.RLock() - if h.emojiJSONCache.Len() > 0 { - defer h.emojiJSONCacheLock.RUnlock() - return c.JSONBlob(http.StatusOK, h.emojiJSONCache.Bytes()) - } - h.emojiJSONCacheLock.RUnlock() - - // 生成 - h.emojiJSONCacheLock.Lock() - defer h.emojiJSONCacheLock.Unlock() - - if h.emojiJSONCache.Len() > 0 { // リロード - return c.JSONBlob(http.StatusOK, h.emojiJSONCache.Bytes()) - } - - if err := generateEmojiJSON(h.Repo, &h.emojiJSONCache); err != nil { - return herror.InternalServerError(err) - } - h.emojiJSONTime = time.Now() - extension.SetLastModified(c, h.emojiJSONTime) - return c.JSONBlob(http.StatusOK, h.emojiJSONCache.Bytes()) -} - -func generateEmojiJSON(repo repository.StampRepository, buf *bytes.Buffer) error { - stamps, err := repo.GetAllStamps(repository.StampTypeAll) + emojiJSON, err := h.EmojiCache.json.Get(context.Background(), struct{}{}) if err != nil { - return err - } - - resData := make(map[string][]string) - arr := make([]string, len(stamps)) - for i, stamp := range stamps { - arr[i] = stamp.Name + return herror.InternalServerError(err) } - resData["all"] = arr - - buf.Reset() - return json.NewEncoder(buf).Encode(resData) + return extension.ServeWithETag(c, echo.MIMEApplicationJSONCharsetUTF8, emojiJSON) } // GetPublicEmojiCSS GET /public/emoji.css func (h *Handlers) GetPublicEmojiCSS(c echo.Context) error { - extension.SetLastModified(c, h.emojiCSSTime) - if done, _ := extension.CheckPreconditions(c, h.emojiCSSTime); done { - return nil - } - - // キャッシュ確認 - h.emojiCSSCacheLock.RLock() - if h.emojiCSSCache.Len() > 0 { - defer h.emojiCSSCacheLock.RUnlock() - return c.Blob(http.StatusOK, "text/css", h.emojiCSSCache.Bytes()) - } - h.emojiCSSCacheLock.RUnlock() - - // 生成 - h.emojiCSSCacheLock.Lock() - defer h.emojiCSSCacheLock.Unlock() - - if h.emojiCSSCache.Len() > 0 { // リロード - return c.Blob(http.StatusOK, "text/css", h.emojiCSSCache.Bytes()) - } - - if err := generateEmojiCSS(h.Repo, &h.emojiCSSCache); err != nil { - return herror.InternalServerError(err) - } - h.emojiCSSTime = time.Now() - extension.SetLastModified(c, h.emojiCSSTime) - return c.Blob(http.StatusOK, "text/css", h.emojiCSSCache.Bytes()) -} - -func generateEmojiCSS(repo repository.StampRepository, buf *bytes.Buffer) error { - stamps, err := repo.GetAllStamps(repository.StampTypeAll) + emojiCSS, err := h.EmojiCache.css.Get(context.Background(), struct{}{}) if err != nil { - return err - } - - buf.Reset() - buf.WriteString(".emoji{display:inline-block;text-indent:999%;white-space:nowrap;overflow:hidden;color:rgba(0,0,0,0);background-size:contain}") - buf.WriteString(".s16{width:16px;height:16px}") - buf.WriteString(".s24{width:24px;height:24px}") - buf.WriteString(".s32{width:32px;height:32px}") - for _, stamp := range stamps { - buf.WriteString(fmt.Sprintf(".emoji.e_%s{background-image:url(/api/1.0/public/emoji/%s)}", stamp.Name, stamp.ID)) + return herror.InternalServerError(err) } - return nil + return extension.ServeWithETag(c, "text/css", emojiCSS) } // GetPublicEmojiImage GET /public/emoji/{stampID} @@ -173,3 +97,63 @@ func (h *Handlers) GetPublicEmojiImage(c echo.Context) error { http.ServeContent(c.Response(), c.Request(), meta.GetFileName(), meta.GetCreatedAt(), file) return nil } + +type EmojiCache struct { + json *sc.Cache[struct{}, []byte] + css *sc.Cache[struct{}, []byte] +} + +func NewEmojiCache(repo repository.Repository) *EmojiCache { + return &EmojiCache{ + json: sc.NewMust(emojiJSONGenerator(repo), 365*24*time.Hour, 365*24*time.Hour), + css: sc.NewMust(emojiCSSGenerator(repo), 365*24*time.Hour, 365*24*time.Hour), + } +} + +// Purge purges cache content. +func (c *EmojiCache) Purge() { + c.json.Purge() + c.css.Purge() +} + +func emojiJSONGenerator(repo repository.Repository) func(_ context.Context, _ struct{}) ([]byte, error) { + return func(_ context.Context, _ struct{}) ([]byte, error) { + stamps, err := repo.GetAllStamps(repository.StampTypeAll) + if err != nil { + return nil, err + } + + stampNames := make([]string, len(stamps)) + for i, stamp := range stamps { + stampNames[i] = stamp.Name + } + + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(map[string][]string{ + "all": stampNames, + }) + if err != nil { + return nil, err + } + return buf.Bytes(), err + } +} + +func emojiCSSGenerator(repo repository.Repository) func(_ context.Context, _ struct{}) ([]byte, error) { + return func(_ context.Context, _ struct{}) ([]byte, error) { + stamps, err := repo.GetAllStamps(repository.StampTypeAll) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + buf.WriteString(".emoji{display:inline-block;text-indent:999%;white-space:nowrap;overflow:hidden;color:rgba(0,0,0,0);background-size:contain}") + buf.WriteString(".s16{width:16px;height:16px}") + buf.WriteString(".s24{width:24px;height:24px}") + buf.WriteString(".s32{width:32px;height:32px}") + for _, stamp := range stamps { + buf.WriteString(fmt.Sprintf(".emoji.e_%s{background-image:url(/api/1.0/public/emoji/%s)}", stamp.Name, stamp.ID)) + } + return buf.Bytes(), nil + } +} diff --git a/router/v1/router.go b/router/v1/router.go index cd706dd15..f11a87745 100644 --- a/router/v1/router.go +++ b/router/v1/router.go @@ -1,11 +1,8 @@ package v1 import ( - "bytes" "encoding/gob" "net/http" - "sync" - "time" "github.com/gofrs/uuid" jsoniter "github.com/json-iterator/go" @@ -47,13 +44,7 @@ type Handlers struct { MessageManager message.Manager FileManager file.Manager Replacer *mutil.Replacer - - emojiJSONCache bytes.Buffer `wire:"-"` - emojiJSONTime time.Time `wire:"-"` - emojiJSONCacheLock sync.RWMutex `wire:"-"` - emojiCSSCache bytes.Buffer `wire:"-"` - emojiCSSTime time.Time `wire:"-"` - emojiCSSCacheLock sync.RWMutex `wire:"-"` + EmojiCache *EmojiCache } // Setup APIルーティングを行います @@ -335,13 +326,7 @@ func (h *Handlers) Setup(e *echo.Group) { func (h *Handlers) stampEventSubscriber(sub hub.Subscription) { for range sub.Receiver { - h.emojiJSONCacheLock.Lock() - h.emojiJSONCache.Reset() - h.emojiJSONCacheLock.Unlock() - - h.emojiCSSCacheLock.Lock() - h.emojiCSSCache.Reset() - h.emojiCSSCacheLock.Unlock() + h.EmojiCache.Purge() } } diff --git a/router/v3/clips.go b/router/v3/clips.go index 779db1074..7f6f88a85 100644 --- a/router/v3/clips.go +++ b/router/v3/clips.go @@ -5,6 +5,7 @@ import ( "strconv" "strings" + "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/service/message" "github.com/traPtitech/traQ/utils/optional" @@ -58,7 +59,6 @@ func (h *Handlers) CreateClipFolder(c echo.Context) error { } return c.JSON(http.StatusCreated, formatClipFolder(cf)) - } // GetClipFolders GET /clip-folders @@ -70,7 +70,7 @@ func (h *Handlers) GetClipFolders(c echo.Context) error { return herror.InternalServerError(err) } - return c.JSON(http.StatusOK, formatClipFolders(cfs)) + return extension.ServeJSONWithETag(c, formatClipFolders(cfs)) } // GetClipFolder GET /clip-folders/:folderID diff --git a/router/v3/responses.go b/router/v3/responses.go index 7e15d0a4d..4dd46d71f 100644 --- a/router/v3/responses.go +++ b/router/v3/responses.go @@ -45,12 +45,8 @@ func formatDMChannels(dmcs map[uuid.UUID]uuid.UUID) []*DMChannel { } sort.Slice(res, func(i, j int) bool { - if res[i].ID.String() == res[j].ID.String() { - return res[i].UserID.String() > res[j].UserID.String() - } - return res[i].ID.String() > res[j].ID.String() + return res[i].ID.String() < res[j].ID.String() }) - return res } @@ -90,6 +86,7 @@ type User struct { UpdatedAt time.Time `json:"updatedAt"` } +// formatUsers ソートされたものを返す func formatUsers(users []model.UserInfo) []User { res := make([]User, len(users)) for i, user := range users { @@ -103,6 +100,10 @@ func formatUsers(users []model.UserInfo) []User { UpdatedAt: user.GetUpdatedAt(), } } + + sort.Slice(res, func(i, j int) bool { + return res[i].ID.String() < res[j].ID.String() + }) return res } @@ -364,6 +365,7 @@ type UserGroupMember struct { Role string `json:"role"` } +// formatUserGroupMembers ソートされたものを返す func formatUserGroupMembers(members []*model.UserGroupMember) []UserGroupMember { arr := make([]UserGroupMember, len(members)) for i, m := range members { @@ -372,14 +374,19 @@ func formatUserGroupMembers(members []*model.UserGroupMember) []UserGroupMember Role: m.Role, } } + + sort.Slice(arr, func(i, j int) bool { return arr[i].ID.String() < arr[j].ID.String() }) return arr } +// formatUserGroupAdmins ソートされたものを返す func formatUserGroupAdmins(admins []*model.UserGroupAdmin) []uuid.UUID { arr := make([]uuid.UUID, len(admins)) for i, m := range admins { arr[i] = m.UserID } + + sort.Slice(arr, func(i, j int) bool { return arr[i].String() < arr[j].String() }) return arr } @@ -410,11 +417,14 @@ func formatUserGroup(g *model.UserGroup) *UserGroup { return ug } +// formatUserGroups ソートされたものを返す func formatUserGroups(gs []*model.UserGroup) []*UserGroup { arr := make([]*UserGroup, len(gs)) for i, g := range gs { arr[i] = formatUserGroup(g) } + + sort.Slice(arr, func(i, j int) bool { return arr[i].ID.String() < arr[j].ID.String() }) return arr } @@ -552,11 +562,13 @@ func formatClipFolder(cf *model.ClipFolder) *ClipFolder { } } +// formatClipFolders ソートされたものを返す func formatClipFolders(cfs []*model.ClipFolder) []*ClipFolder { res := make([]*ClipFolder, len(cfs)) for i, cf := range cfs { res[i] = formatClipFolder(cf) } + sort.Slice(res, func(i, j int) bool { return res[i].ID.String() < res[j].ID.String() }) return res } @@ -602,10 +614,12 @@ func formatStampPalette(cf *model.StampPalette) *StampPalette { } } +// formatStampPalettes ソートされたものを返す func formatStampPalettes(cfs []*model.StampPalette) []*StampPalette { res := make([]*StampPalette, len(cfs)) for i, cf := range cfs { res[i] = formatStampPalette(cf) } + sort.Slice(res, func(i, j int) bool { return res[i].ID.String() < res[j].ID.String() }) return res } diff --git a/router/v3/router_test.go b/router/v3/router_test.go index bbbca74e0..d843d8acc 100644 --- a/router/v3/router_test.go +++ b/router/v3/router_test.go @@ -91,13 +91,11 @@ func TestMain(m *testing.M) { env.SessStore = session.NewMemorySessionStore() // テスト用リポジトリ作成 - repo, err := gorm2.NewGormRepository(engine, env.Hub, l.Named("repository")) + repo, init, err := gorm2.NewGormRepository(engine, env.Hub, l.Named("repository"), true) if err != nil { panic(err) } - if init, err := repo.Sync(); err != nil { - panic(err) - } else if init { + if init { // システムユーザーロール投入 if err := repo.CreateUserRoles(role.SystemRoleModels()...); err != nil { panic(err) diff --git a/router/v3/stamp_palettes.go b/router/v3/stamp_palettes.go index e8752f65d..d9c2ac039 100644 --- a/router/v3/stamp_palettes.go +++ b/router/v3/stamp_palettes.go @@ -3,6 +3,7 @@ package v3 import ( "net/http" + "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/utils/optional" vd "github.com/go-ozzo/ozzo-validation/v4" @@ -23,7 +24,7 @@ func (h *Handlers) GetStampPalettes(c echo.Context) error { return herror.InternalServerError(err) } - return c.JSON(http.StatusOK, formatStampPalettes(palettes)) + return extension.ServeJSONWithETag(c, formatStampPalettes(palettes)) } // CreateStampPaletteRequest POST /stamp-palettes リクエストボディ diff --git a/router/v3/stamps.go b/router/v3/stamps.go index a0fec5dcd..a4f209c7a 100644 --- a/router/v3/stamps.go +++ b/router/v3/stamps.go @@ -65,17 +65,12 @@ func (h *Handlers) GetStamps(c echo.Context) error { stampType = repository.StampTypeOriginal } - b, updatedAt, err := h.Repo.GetStampsJSON(stampType) + stamps, err := h.Repo.GetAllStamps(stampType) if err != nil { return herror.InternalServerError(err) } - c.Response().Header().Set(consts.HeaderCacheControl, "private, max-age=0") // 鮮度を0にして毎回キャッシュ検証させる - extension.SetLastModified(c, updatedAt) - if done, err := extension.CheckPreconditions(c, updatedAt); done { - return err - } - return c.JSONBlob(http.StatusOK, b) + return extension.ServeJSONWithETag(c, stamps) } // CreateStamp POST /stamps diff --git a/router/v3/star.go b/router/v3/star.go index f3cddecbe..f1a943688 100644 --- a/router/v3/star.go +++ b/router/v3/star.go @@ -3,12 +3,14 @@ package v3 import ( "context" "net/http" + "sort" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/router/consts" + "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" "github.com/traPtitech/traQ/router/utils" "github.com/traPtitech/traQ/utils/validator" @@ -23,7 +25,8 @@ func (h *Handlers) GetMyStars(c echo.Context) error { return herror.InternalServerError(err) } - return c.JSON(http.StatusOK, stars) + sort.Slice(stars, func(i, j int) bool { return stars[i].String() < stars[j].String() }) + return extension.ServeJSONWithETag(c, stars) } // PostStarRequest POST /users/me/stars リクエストボディ diff --git a/router/v3/user_groups.go b/router/v3/user_groups.go index 94be0889e..bea12fa7d 100644 --- a/router/v3/user_groups.go +++ b/router/v3/user_groups.go @@ -10,6 +10,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" + "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" "github.com/traPtitech/traQ/router/utils" file2 "github.com/traPtitech/traQ/service/file" @@ -24,7 +25,7 @@ func (h *Handlers) GetUserGroups(c echo.Context) error { if err != nil { return herror.InternalServerError(err) } - return c.JSON(http.StatusOK, formatUserGroups(gs)) + return extension.ServeJSONWithETag(c, formatUserGroups(gs)) } // PostUserGroupRequest POST /groups リクエストボディ diff --git a/router/v3/users.go b/router/v3/users.go index fe7e4078f..b45f896f4 100644 --- a/router/v3/users.go +++ b/router/v3/users.go @@ -3,6 +3,7 @@ package v3 import ( "context" "net/http" + "sort" "time" vd "github.com/go-ozzo/ozzo-validation/v4" @@ -14,6 +15,7 @@ import ( "github.com/traPtitech/traQ/model" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" + "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" "github.com/traPtitech/traQ/router/utils" "github.com/traPtitech/traQ/service/channel" @@ -42,7 +44,7 @@ func (h *Handlers) GetUsers(c echo.Context) error { if err != nil { return herror.InternalServerError(err) } - return c.JSON(http.StatusOK, formatUsers(users)) + return extension.ServeJSONWithETag(c, formatUsers(users)) } // PostUserRequest POST /users リクエストボディ @@ -399,8 +401,9 @@ func (h *Handlers) GetMyChannelSubscriptions(c echo.Context) error { for i, subscription := range subscriptions { result[i] = response{ChannelID: subscription.ChannelID, Level: subscription.GetLevel().Int()} } + sort.Slice(result, func(i, j int) bool { return result[i].ChannelID.String() < result[j].ChannelID.String() }) - return c.JSON(http.StatusOK, result) + return extension.ServeJSONWithETag(c, result) } // PutChannelSubscribeLevelRequest PUT /users/me/subscriptions/:channelID リクエストボディ diff --git a/router/wire_gen.go b/router/wire_gen.go index ff75305a8..78d67fa64 100644 --- a/router/wire_gen.go +++ b/router/wire_gen.go @@ -31,6 +31,7 @@ func newRouter(hub2 *hub.Hub, db *gorm.DB, repo repository.Repository, ss *servi fileManager := ss.FileManager replaceMapper := utils.NewReplaceMapper(repo, manager) replacer := message.NewReplacer(replaceMapper) + emojiCache := v1.NewEmojiCache(repo) handlers := &v1.Handlers{ RBAC: rbac, Repo: repo, @@ -41,6 +42,7 @@ func newRouter(hub2 *hub.Hub, db *gorm.DB, repo repository.Repository, ss *servi MessageManager: messageManager, FileManager: fileManager, Replacer: replacer, + EmojiCache: emojiCache, } streamer := ss.WS wsStreamer := ss.BotWS diff --git a/testutils/empty_test_repository.go b/testutils/empty_test_repository.go index bc7039327..3bd1f1656 100644 --- a/testutils/empty_test_repository.go +++ b/testutils/empty_test_repository.go @@ -25,7 +25,3 @@ type EmptyTestRepository struct { repository.ClipRepository repository.OgpCacheRepository } - -func (*EmptyTestRepository) Sync() (init bool, err error) { - return false, nil -}