From ea59809b0cb67b93a15a5cd24dc6966bf3751b2b Mon Sep 17 00:00:00 2001 From: seelly Date: Sat, 7 Sep 2024 17:12:29 +0800 Subject: [PATCH] Add Redis-based distributed lock with Lua script for atomicity --- biz/dal/mysql/shop.go | 11 ++------ biz/dal/mysql/voucher.go | 2 -- biz/dal/redis/init.go | 9 +++--- biz/dal/redis/lock.go | 31 ++++++++++++++++++--- biz/middleware/interceptor/refresh_token.go | 2 +- biz/service/voucher/seckill_voucher.go | 13 +++++++-- 6 files changed, 47 insertions(+), 21 deletions(-) diff --git a/biz/dal/mysql/shop.go b/biz/dal/mysql/shop.go index 0f4258d..e787bf3 100644 --- a/biz/dal/mysql/shop.go +++ b/biz/dal/mysql/shop.go @@ -101,9 +101,10 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) { if err == nil { return result, nil } - + lock := redis.NewLock(ctx, lockKey, key, 10) // 2. 缓存未命中,尝试获取锁 - isLocked := redis.TryLock(ctx, lockKey) + isLocked := lock.TryLock() + defer lock.UnLock(key) if !isLocked { // 锁获取失败,等待后重试 time.Sleep(50 * time.Millisecond) @@ -113,7 +114,6 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) { // 2.2 获取锁成功,再次检查缓存 result, err = redis.GetShopFromCache(ctx, key) if err == nil { - redis.UnLock(ctx, lockKey) return result, nil } @@ -122,22 +122,17 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) { if err := DB.First(&shop, id).Error; err != nil { if err == gorm.ErrRecordNotFound { redis.RedisClient.Set(ctx, key, "", constants.CACHE_NULL_TTL).Err() - redis.UnLock(ctx, lockKey) return nil, err } - redis.UnLock(ctx, lockKey) return nil, err } // 4. 数据库中存在,缓存数据 shopJson, err := json.Marshal(shop) if err != nil { - redis.UnLock(ctx, lockKey) return nil, err } redis.RedisClient.Set(ctx, key, string(shopJson), constants.CACHE_SHOP_TTL).Err() - - redis.UnLock(ctx, lockKey) return &shop, nil } diff --git a/biz/dal/mysql/voucher.go b/biz/dal/mysql/voucher.go index 69601f6..36ac5c9 100644 --- a/biz/dal/mysql/voucher.go +++ b/biz/dal/mysql/voucher.go @@ -3,7 +3,6 @@ package mysql import ( "context" "errors" - "fmt" "gorm.io/gorm" "xzdp/biz/model/voucher" ) @@ -29,7 +28,6 @@ func QueryVoucherByID(ctx context.Context, id int64) (*voucher.SeckillVoucher, e func QueryVoucherOrderByVoucherID(ctx context.Context, userId int64, id int64) error { var voucherOrder voucher.VoucherOrder err = DB.WithContext(ctx).Where("voucher_id = ? and user_id=?", id, userId).Limit(1).Find(&voucherOrder).Error - fmt.Printf("voucherOrder: %v\n", voucherOrder) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("重复下单") diff --git a/biz/dal/redis/init.go b/biz/dal/redis/init.go index f60bced..33814a2 100644 --- a/biz/dal/redis/init.go +++ b/biz/dal/redis/init.go @@ -59,10 +59,11 @@ func SetStringLogical(ctx context.Context, key string, value interface{}, durati func GetStringLogical(ctx context.Context, key string, duration time.Duration, dbFallback ArgsFunc, args ...interface{}) (string, error) { redisJson, err := RedisClient.Get(ctx, key).Result() + lock := NewLock(ctx, constants.LOCK_KEY+key, "lock", duration) if redisJson == "" || errors.Is(err, redis.Nil) { - if TryLock(ctx, constants.LOCK_KEY+key) { + if lock.TryLock() { go func() { - defer UnLock(ctx, constants.LOCK_KEY+key) + defer lock.UnLock("lock") data, err := dbFallback(args...) if err != nil { return @@ -81,9 +82,9 @@ func GetStringLogical(ctx context.Context, key string, duration time.Duration, d if redisData.ExpiredTime.After(time.Now()) { return redisData.Data, nil } else { - if TryLock(ctx, constants.LOCK_KEY+key) { + if lock.TryLock() { go func() { - defer UnLock(ctx, constants.LOCK_KEY+key) + defer lock.UnLock("lock") data, err := dbFallback(args...) if err != nil { return diff --git a/biz/dal/redis/lock.go b/biz/dal/redis/lock.go index 45e47b3..a8c5640 100644 --- a/biz/dal/redis/lock.go +++ b/biz/dal/redis/lock.go @@ -6,8 +6,25 @@ import ( "time" ) -func TryLock(ctx context.Context, key string) bool { - success, err := RedisClient.SetNX(ctx, key, "1", 10*time.Second).Result() +type Lock struct { + Ctx context.Context + Key string + Value string + Expire time.Duration +} + +func NewLock(ctx context.Context, key string, value string, expire time.Duration) *Lock { + return &Lock{ + Ctx: ctx, + Key: key, + Value: value, + Expire: expire, + } +} + +func (l *Lock) TryLock() bool { + // value 应当全局唯一 + success, err := RedisClient.SetNX(l.Ctx, l.Key, l.Value, l.Expire*time.Second).Result() if err != nil { log.Printf("Error acquiring lock: %v", err) return false @@ -15,6 +32,12 @@ func TryLock(ctx context.Context, key string) bool { return success } -func UnLock(ctx context.Context, key string) { - RedisClient.Del(ctx, key) +func (l *Lock) UnLock(value string) { + luaScript := ` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) + end + return 0 + ` + RedisClient.Eval(l.Ctx, luaScript, []string{l.Key}, value) } diff --git a/biz/middleware/interceptor/refresh_token.go b/biz/middleware/interceptor/refresh_token.go index 08721c8..7639c0c 100644 --- a/biz/middleware/interceptor/refresh_token.go +++ b/biz/middleware/interceptor/refresh_token.go @@ -14,7 +14,7 @@ import ( func CheckToken(ctx context.Context, c *app.RequestContext) { hlog.CtxInfof(ctx, "check token interceptor:%+v", conf.GetEnv()) - if conf.GetEnv() == "dev" { + if conf.GetEnv() != "online" { userdto := model.UserDTO{ ID: 2, NickName: "法外狂徒张三", diff --git a/biz/service/voucher/seckill_voucher.go b/biz/service/voucher/seckill_voucher.go index 9dd655a..bcdf333 100644 --- a/biz/service/voucher/seckill_voucher.go +++ b/biz/service/voucher/seckill_voucher.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/cloudwego/hertz/pkg/app" + "strconv" "sync" "time" "xzdp/biz/dal/mysql" @@ -49,8 +50,16 @@ func (h *SeckillVoucherService) Run(req *int64) (resp *int64, err error) { if voucher.GetStock() <= 0 { return nil, errors.New("已抢空") } - mu.Lock() - defer mu.Unlock() + user := utils.GetUser(h.Context) + uuid, _ := utils.RandomUUID() + sec := time.Now().Unix() + lockValue := uuid + strconv.FormatInt(sec, 10) //由于value的全局唯一性,这里用uuid+时间戳,如需要更高精度应考虑雪花算法活其他方法生成 + lock := redis.NewLock(h.Context, user.NickName, lockValue, 10) + ok := lock.TryLock() + if !ok { + return nil, errors.New("重复下单") + } + defer lock.UnLock(lockValue) return h.createOrder(*req) }