Skip to content

Commit

Permalink
Add Redis-based distributed lock with Lua script for atomicity
Browse files Browse the repository at this point in the history
  • Loading branch information
Seelly committed Sep 7, 2024
1 parent 814037f commit ea59809
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
11 changes: 3 additions & 8 deletions biz/dal/mysql/shop.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) {
if err == nil {
return result, nil
}

lock := redis.NewLock(ctx, lockKey, key, 10)
// 2. 缓存未命中,尝试获取锁
isLocked := redis.TryLock(ctx, lockKey)
isLocked := lock.TryLock()
defer lock.UnLock(key)
if !isLocked {
// 锁获取失败,等待后重试
time.Sleep(50 * time.Millisecond)
Expand All @@ -113,7 +114,6 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) {
// 2.2 获取锁成功,再次检查缓存
result, err = redis.GetShopFromCache(ctx, key)
if err == nil {
redis.UnLock(ctx, lockKey)
return result, nil
}

Expand All @@ -122,22 +122,17 @@ func queryByID2(ctx context.Context, id int64) (*shop.Shop, error) {
if err := DB.First(&shop, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
redis.RedisClient.Set(ctx, key, "", constants.CACHE_NULL_TTL).Err()
redis.UnLock(ctx, lockKey)
return nil, err
}
redis.UnLock(ctx, lockKey)
return nil, err
}

// 4. 数据库中存在,缓存数据
shopJson, err := json.Marshal(shop)
if err != nil {
redis.UnLock(ctx, lockKey)
return nil, err
}
redis.RedisClient.Set(ctx, key, string(shopJson), constants.CACHE_SHOP_TTL).Err()

redis.UnLock(ctx, lockKey)
return &shop, nil
}

Expand Down
2 changes: 0 additions & 2 deletions biz/dal/mysql/voucher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package mysql
import (
"context"
"errors"
"fmt"
"gorm.io/gorm"
"xzdp/biz/model/voucher"
)
Expand All @@ -29,7 +28,6 @@ func QueryVoucherByID(ctx context.Context, id int64) (*voucher.SeckillVoucher, e
func QueryVoucherOrderByVoucherID(ctx context.Context, userId int64, id int64) error {
var voucherOrder voucher.VoucherOrder
err = DB.WithContext(ctx).Where("voucher_id = ? and user_id=?", id, userId).Limit(1).Find(&voucherOrder).Error
fmt.Printf("voucherOrder: %v\n", voucherOrder)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("重复下单")
Expand Down
9 changes: 5 additions & 4 deletions biz/dal/redis/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ func SetStringLogical(ctx context.Context, key string, value interface{}, durati

func GetStringLogical(ctx context.Context, key string, duration time.Duration, dbFallback ArgsFunc, args ...interface{}) (string, error) {
redisJson, err := RedisClient.Get(ctx, key).Result()
lock := NewLock(ctx, constants.LOCK_KEY+key, "lock", duration)
if redisJson == "" || errors.Is(err, redis.Nil) {
if TryLock(ctx, constants.LOCK_KEY+key) {
if lock.TryLock() {
go func() {
defer UnLock(ctx, constants.LOCK_KEY+key)
defer lock.UnLock("lock")
data, err := dbFallback(args...)
if err != nil {
return
Expand All @@ -81,9 +82,9 @@ func GetStringLogical(ctx context.Context, key string, duration time.Duration, d
if redisData.ExpiredTime.After(time.Now()) {
return redisData.Data, nil
} else {
if TryLock(ctx, constants.LOCK_KEY+key) {
if lock.TryLock() {
go func() {
defer UnLock(ctx, constants.LOCK_KEY+key)
defer lock.UnLock("lock")
data, err := dbFallback(args...)
if err != nil {
return
Expand Down
31 changes: 27 additions & 4 deletions biz/dal/redis/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,38 @@ import (
"time"
)

func TryLock(ctx context.Context, key string) bool {
success, err := RedisClient.SetNX(ctx, key, "1", 10*time.Second).Result()
type Lock struct {
Ctx context.Context
Key string
Value string
Expire time.Duration
}

func NewLock(ctx context.Context, key string, value string, expire time.Duration) *Lock {
return &Lock{
Ctx: ctx,
Key: key,
Value: value,
Expire: expire,
}
}

func (l *Lock) TryLock() bool {
// value 应当全局唯一
success, err := RedisClient.SetNX(l.Ctx, l.Key, l.Value, l.Expire*time.Second).Result()
if err != nil {
log.Printf("Error acquiring lock: %v", err)
return false
}
return success
}

func UnLock(ctx context.Context, key string) {
RedisClient.Del(ctx, key)
func (l *Lock) UnLock(value string) {
luaScript := `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
end
return 0
`
RedisClient.Eval(l.Ctx, luaScript, []string{l.Key}, value)
}
2 changes: 1 addition & 1 deletion biz/middleware/interceptor/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func CheckToken(ctx context.Context, c *app.RequestContext) {
hlog.CtxInfof(ctx, "check token interceptor:%+v", conf.GetEnv())
if conf.GetEnv() == "dev" {
if conf.GetEnv() != "online" {
userdto := model.UserDTO{
ID: 2,
NickName: "法外狂徒张三",
Expand Down
13 changes: 11 additions & 2 deletions biz/service/voucher/seckill_voucher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/cloudwego/hertz/pkg/app"
"strconv"
"sync"
"time"
"xzdp/biz/dal/mysql"
Expand Down Expand Up @@ -49,8 +50,16 @@ func (h *SeckillVoucherService) Run(req *int64) (resp *int64, err error) {
if voucher.GetStock() <= 0 {
return nil, errors.New("已抢空")
}
mu.Lock()
defer mu.Unlock()
user := utils.GetUser(h.Context)
uuid, _ := utils.RandomUUID()
sec := time.Now().Unix()
lockValue := uuid + strconv.FormatInt(sec, 10) //由于value的全局唯一性,这里用uuid+时间戳,如需要更高精度应考虑雪花算法活其他方法生成
lock := redis.NewLock(h.Context, user.NickName, lockValue, 10)
ok := lock.TryLock()
if !ok {
return nil, errors.New("重复下单")
}
defer lock.UnLock(lockValue)
return h.createOrder(*req)
}

Expand Down

0 comments on commit ea59809

Please sign in to comment.