diff --git a/doc/paths_user.yaml b/doc/paths_user.yaml index a9ae86b..0e700eb 100644 --- a/doc/paths_user.yaml +++ b/doc/paths_user.yaml @@ -76,6 +76,7 @@ parameters: - $ref: '#/components/parameters/HeaderAuthorization' - $ref: "#/components/parameters/PathUID" + - $ref: "#/components/parameters/QueryProduct" responses: '200': $ref: '#/components/responses/UserRes' diff --git a/src/api/label_test.go b/src/api/label_test.go index 28eac13..73efffc 100644 --- a/src/api/label_test.go +++ b/src/api/label_test.go @@ -689,11 +689,15 @@ func TestLabelAPIs(t *testing.T) { label, err := createLabel(tt, product.Name) assert.Nil(t, err) + label2, err := createLabel(tt, product.Name) + assert.Nil(t, err) + users, err := createUsers(tt, 1) assert.Nil(t, err) user := users[0] var rule tpl.LabelRuleInfo + var rule2 tpl.LabelRuleInfo t.Run(`"POST /v1/products/:product/labels/:label/rules" should work`, func(t *testing.T) { assert := assert.New(t) @@ -764,6 +768,81 @@ func TestLabelAPIs(t *testing.T) { assert.Equal(label.Name, data.Label) }) + t.Run(`"POST /v1/products/:product/labels/:label/rules" label2 should work`, func(t *testing.T) { + assert := assert.New(t) + res, err := request.Post(fmt.Sprintf("%s/v1/products/%s/labels/%s/rules", tt.Host, product.Name, label2.Name)). + Set("Content-Type", "application/json"). + Send(map[string]interface{}{ + "kind": "userPercent", + "rule": map[string]interface{}{ + "value": 100, + }, + }). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + text, err := res.Text() + assert.Nil(err) + assert.True(strings.Contains(text, `"rule":{"value":100}`)) + assert.False(strings.Contains(text, `"id"`)) + + json := tpl.LabelRuleInfoRes{} + res.JSON(&json) + data := json.Result + assert.True(service.HIDToID(data.HID, "label_rule") > int64(0)) + assert.Equal(label2.ID, service.HIDToID(data.LabelHID, "label")) + assert.Equal("userPercent", data.Kind) + assert.True(data.CreatedAt.UTC().Unix() > int64(0)) + assert.True(data.UpdatedAt.UTC().Unix() > int64(0)) + assert.Equal(int64(1), data.Release) + + rule2 = data + }) + + t.Run(`"PUT /users/:uid/labels:cache" should not apply label2 rules`, func(t *testing.T) { + assert := assert.New(t) + res, err := request.Put(fmt.Sprintf("%s/v1/users/%s/labels:cache?product=%s", tt.Host, user.UID, product.Name)). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + text, err := res.Text() + assert.Nil(err) + assert.False(strings.Contains(text, `"id"`)) + + json := tpl.UserRes{} + _, err = res.JSON(&json) + + assert.Nil(err) + assert.Equal(1, len(json.Result.GetLabels(product.Name))) + + data := json.Result.GetLabels(product.Name) + assert.Equal(label.Name, data[0].Label) + }) + + t.Run(`"PUT /users/:uid/labels:cache" should apply label、label2 rules without product`, func(t *testing.T) { + assert := assert.New(t) + res, err := request.Put(fmt.Sprintf("%s/v1/users/%s/labels:cache", tt.Host, user.UID)). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + text, err := res.Text() + assert.Nil(err) + assert.False(strings.Contains(text, `"id"`)) + + json := tpl.UserRes{} + _, err = res.JSON(&json) + + assert.Nil(err) + assert.Equal(2, len(json.Result.GetLabels(product.Name))) + + data := json.Result.GetLabels(product.Name) + assert.Equal(label2.Name, data[0].Label) + assert.Equal(label.Name, data[1].Label) + }) + t.Run(`"GET /users/:uid/labels:cache" should support anonymous user`, func(t *testing.T) { assert := assert.New(t) res, err := request.Get(fmt.Sprintf("%s/users/%s/labels:cache?product=%s", tt.Host, "anon-"+user.UID, product.Name)). @@ -779,10 +858,10 @@ func TestLabelAPIs(t *testing.T) { _, err = res.JSON(&json) assert.Nil(err) - assert.Equal(1, len(json.Result)) + assert.Equal(2, len(json.Result)) - data := json.Result[0] - assert.Equal(label.Name, data.Label) + assert.Equal(label2.Name, json.Result[0].Label) + assert.Equal(label.Name, json.Result[1].Label) }) t.Run(`"GET /v1/products/:product/labels/:label/rules" should work`, func(t *testing.T) { @@ -877,5 +956,39 @@ func TestLabelAPIs(t *testing.T) { assert.Nil(err) assert.False(json.Result) }) + + t.Run(`"DELETE /v1/products/:product/labels/:label/rules/:hid" label2 should work`, func(t *testing.T) { + assert := assert.New(t) + res, err := request.Delete(fmt.Sprintf("%s/v1/products/%s/labels/%s/rules/%s", tt.Host, product.Name, label2.Name, rule2.HID)). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + json := tpl.BoolRes{} + _, err = res.JSON(&json) + assert.Nil(err) + assert.True(json.Result) + + res, err = request.Get(fmt.Sprintf("%s/v1/products/%s/labels/%s/rules", tt.Host, product.Name, label2.Name)). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + json2 := tpl.LabelRulesInfoRes{} + _, err = res.JSON(&json2) + + assert.Nil(err) + assert.Equal(0, len(json2.Result)) + + res, err = request.Delete(fmt.Sprintf("%s/v1/products/%s/labels/%s/rules/%s", tt.Host, product.Name, label2.Name, rule2.HID)). + End() + assert.Nil(err) + assert.Equal(200, res.StatusCode) + + json = tpl.BoolRes{} + _, err = res.JSON(&json) + assert.Nil(err) + assert.False(json.Result) + }) }) } diff --git a/src/api/user.go b/src/api/user.go index c4ad3e3..83ebe9d 100644 --- a/src/api/user.go +++ b/src/api/user.go @@ -39,12 +39,12 @@ func (a *User) ListCachedLabels(ctx *gear.Context) error { // RefreshCachedLabels 强制更新 user 的 labels 缓存 func (a *User) RefreshCachedLabels(ctx *gear.Context) error { - req := tpl.UIDURL{} + req := tpl.UIDAndProductURL{} if err := ctx.ParseURL(&req); err != nil { return err } - user, err := a.blls.User.RefreshCachedLabels(ctx, req.UID) + user, err := a.blls.User.RefreshCachedLabels(ctx, req.Product, req.UID) if err != nil { return err } diff --git a/src/bll/user.go b/src/bll/user.go index 35e2dd4..451b551 100644 --- a/src/bll/user.go +++ b/src/bll/user.go @@ -66,12 +66,12 @@ func (b *User) ListCachedLabels(ctx context.Context, uid, product string) *tpl.C // user 上缓存的 labels 过期,则刷新获取最新,RefreshUser 要考虑并发场景 if user.ActiveAt == 0 { - if user = b.ms.TryApplyLabelRulesAndRefreshUserLabels(ctx, user.ID, now, true); user == nil { + if user = b.ms.TryApplyLabelRulesAndRefreshUserLabels(ctx, productID, product, user.ID, now, true); user == nil { return res } } else if conf.Config.IsCacheLabelExpired(now.Unix()-5, user.ActiveAt) { // 提前 5s 异步处理 util.Go(10*time.Second, func(gctx context.Context) { - b.ms.TryApplyLabelRulesAndRefreshUserLabels(gctx, user.ID, now, false) + b.ms.TryApplyLabelRulesAndRefreshUserLabels(gctx, productID, product, user.ID, now, false) }) } @@ -81,13 +81,20 @@ func (b *User) ListCachedLabels(ctx context.Context, uid, product string) *tpl.C } // RefreshCachedLabels ... -func (b *User) RefreshCachedLabels(ctx context.Context, uid string) (*schema.User, error) { +func (b *User) RefreshCachedLabels(ctx context.Context, product, uid string) (*schema.User, error) { user, err := b.ms.User.Acquire(ctx, uid) if err != nil { return nil, err } - - if user, err = b.ms.ApplyLabelRulesAndRefreshUserLabels(ctx, user.ID, time.Now().UTC(), true); err != nil { + readCtx := context.WithValue(ctx, model.ReadDB, true) + var productID int64 = 0 + if product != "" { + productID, err = b.ms.Product.AcquireID(readCtx, product) + if err != nil { + return nil, err + } + } + if user, err = b.ms.ApplyLabelRulesAndRefreshUserLabels(ctx, productID, product, user.ID, time.Now().UTC(), true); err != nil { return nil, err } return user, nil diff --git a/src/model/common.go b/src/model/common.go index fa9b740..224de7c 100644 --- a/src/model/common.go +++ b/src/model/common.go @@ -65,10 +65,11 @@ func NewModels(sql *service.SQL) *Models { // ***** 以下为需要组合多个 model 接口能力而对外暴露的接口 ***** // ApplyLabelRulesAndRefreshUserLabels ... -func (ms *Models) ApplyLabelRulesAndRefreshUserLabels(ctx context.Context, userID int64, now time.Time, force bool) (*schema.User, error) { +func (ms *Models) ApplyLabelRulesAndRefreshUserLabels(ctx context.Context, productID int64, product string, userID int64, now time.Time, force bool) (*schema.User, error) { user, labelIDs, ok, err := ms.User.RefreshLabels(ctx, userID, now.Unix(), force) - if ok { - hit, err := ms.LabelRule.ApplyRules(ctx, userID, labelIDs) + userProductLables := user.GetLabels(product) + if ok && len(userProductLables) == 0 { + hit, err := ms.LabelRule.ApplyRules(ctx, productID, userID, labelIDs) if err != nil { return nil, err } @@ -86,8 +87,8 @@ func (ms *Models) ApplyLabelRulesAndRefreshUserLabels(ctx context.Context, userI } // TryApplyLabelRulesAndRefreshUserLabels ... -func (ms *Models) TryApplyLabelRulesAndRefreshUserLabels(ctx context.Context, userID int64, now time.Time, force bool) *schema.User { - user, err := ms.ApplyLabelRulesAndRefreshUserLabels(ctx, userID, now, force) +func (ms *Models) TryApplyLabelRulesAndRefreshUserLabels(ctx context.Context, productID int64, product string, userID int64, now time.Time, force bool) *schema.User { + user, err := ms.ApplyLabelRulesAndRefreshUserLabels(ctx, productID, product, userID, now, force) if err != nil { logging.Warningf("ApplyLabelRulesAndRefreshUserLabels: userID %d, error %v", userID, err) return nil diff --git a/src/model/label_rule.go b/src/model/label_rule.go index 3508422..84bee0b 100644 --- a/src/model/label_rule.go +++ b/src/model/label_rule.go @@ -6,6 +6,7 @@ import ( "time" "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" "github.com/teambition/urbs-setting/src/schema" "github.com/teambition/urbs-setting/src/service" "github.com/teambition/urbs-setting/src/tpl" @@ -18,11 +19,14 @@ type LabelRule struct { } // ApplyRules ... -func (m *LabelRule) ApplyRules(ctx context.Context, userID int64, excludeLabels []int64) (int, error) { +func (m *LabelRule) ApplyRules(ctx context.Context, productID int64, userID int64, excludeLabels []int64) (int, error) { rules := []schema.LabelRule{} // 不把 excludeLabels 放入查询条件,从而尽量复用查询缓存 - sd := m.RdDB.From(schema.TableLabelRule). - Where(goqu.C("kind").Eq("userPercent")).Order(goqu.C("updated_at").Desc()).Limit(200) + exps := []exp.Expression{goqu.C("kind").Eq("userPercent")} + if productID > 0 { + exps = append(exps, goqu.C("product_id").Eq(productID)) + } + sd := m.RdDB.From(schema.TableLabelRule).Where(exps...).Order(goqu.C("updated_at").Desc()).Limit(200) err := sd.Executor().ScanStructsContext(ctx, &rules) if err != nil { return 0, err diff --git a/src/tpl/common.go b/src/tpl/common.go index fa63e2e..29979ed 100644 --- a/src/tpl/common.go +++ b/src/tpl/common.go @@ -84,6 +84,23 @@ func (t *UIDURL) Validate() error { return nil } +// UIDAndProductURL ... +type UIDAndProductURL struct { + UID string `json:"uid" param:"uid"` + Product string `json:"product" query:"product"` +} + +// Validate 实现 gear.BodyTemplate。 +func (t *UIDAndProductURL) Validate() error { + if !validIDReg.MatchString(t.UID) { + return gear.ErrBadRequest.WithMsgf("invalid uid: %s", t.UID) + } + if t.Product != "" && !validNameReg.MatchString(t.Product) { + return gear.ErrBadRequest.WithMsgf("invalid product name: %s", t.Product) + } + return nil +} + // UIDPaginationURL ... type UIDPaginationURL struct { Pagination