From 359f39ff43164b68768c70080e3dcb5cad744576 Mon Sep 17 00:00:00 2001 From: perror <23651751+perrornet@users.noreply.github.com> Date: Sat, 15 Jun 2024 14:14:16 +0800 Subject: [PATCH] Refactor notice struct and improve http server handling --- cmd/configs.go | 79 ++++++++++++------- cmd/main.go | 49 +++++++++++- internal/daemons/init.go | 6 +- internal/daemons/rebalance/rebalance.go | 31 ++++++-- internal/daemons/utils.go | 2 +- internal/models/order.go | 16 ++-- utils/chains/util.go | 4 +- utils/configs/config.go | 13 +-- utils/notice/notice.go | 4 +- utils/provider/bridge/darwinia/darwinia.go | 1 + utils/provider/bridge/helix/helix.go | 1 + utils/provider/bridge/okx/okx.go | 2 +- .../bridge/routernitro/routernitro.go | 2 +- utils/provider/dex/uniswap/uniswap.go | 2 +- utils/provider/utils.go | 17 ++++ utils/wallets/safe/safe.go | 9 ++- 16 files changed, 175 insertions(+), 63 deletions(-) diff --git a/cmd/configs.go b/cmd/configs.go index 2564945..d0a0e0b 100644 --- a/cmd/configs.go +++ b/cmd/configs.go @@ -12,6 +12,8 @@ import ( yaml_ncoder "github.com/zwgblue/yaml-encoder" "net/http" "omni-balance/internal/daemons" + "omni-balance/internal/db" + "omni-balance/internal/models" "omni-balance/utils/configs" "omni-balance/utils/constant" "os" @@ -26,46 +28,63 @@ var ( setPlaceholderFinished = make(chan struct{}, 1) ) -func startHttpServer(ctx context.Context, port string) (func(ctx context.Context) error, error) { - server := &http.Server{ - Addr: port, - Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if !strings.EqualFold(request.Method, http.MethodPost) { - writer.WriteHeader(http.StatusMethodNotAllowed) - return - } - var args = make(map[string]interface{}) - if err := json.NewDecoder(request.Body).Decode(&args); err != nil { - writer.WriteHeader(http.StatusBadRequest) - return - } - for k, v := range args { - placeholder.Store(k, v) - } +func startHttpServer(_ context.Context, port string) error { + http.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if !strings.EqualFold(request.Method, http.MethodPost) { + writer.WriteHeader(http.StatusMethodNotAllowed) + return + } + var args = make(map[string]interface{}) + if err := json.NewDecoder(request.Body).Decode(&args); err != nil { + writer.WriteHeader(http.StatusBadRequest) + return + } + for k, v := range args { + placeholder.Store(k, v) + } + + setPlaceholderFinished <- struct{}{} + })) - setPlaceholderFinished <- struct{}{} - }), + http.Handle("/remove_order", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if !strings.EqualFold(request.Method, http.MethodPost) { + writer.WriteHeader(http.StatusMethodNotAllowed) + return + } + var order = struct { + Id int `json:"id" form:"id"` + }{} + if err := json.NewDecoder(request.Body).Decode(&order); err != nil { + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte(err.Error())) + return + } + err := db.DB().Model(&models.Order{}).Where("id = ?", order.Id).Limit(1).Delete(&models.Order{}).Error + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + return + } + writer.WriteHeader(http.StatusOK) + })) + server := &http.Server{ + Addr: port, + Handler: http.DefaultServeMux, } go func() { if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { logrus.Panic(err) } }() - return server.Shutdown, nil + logrus.Infof("http server started on %s", port) + return nil } -func waitForPlaceholder(ctx context.Context, configPath, port string) (newConfigPath string, err error) { +func waitForPlaceholder(_ context.Context, configPath string) (newConfigPath string, err error) { data, err := os.ReadFile(configPath) if err != nil { return "", err } - shutdown, err := startHttpServer(ctx, port) - if err != nil { - return "", err - } - defer func() { - _ = shutdown(ctx) - }() + <-setPlaceholderFinished placeholder.Range(func(key, value interface{}) bool { data = bytes.ReplaceAll(data, []byte(key.(string)), []byte(cast.ToString(value))) @@ -79,13 +98,17 @@ func waitForPlaceholder(ctx context.Context, configPath, port string) (newConfig } func initConfig(ctx context.Context, enablePlaceholder bool, configPath, serverPort string) (err error) { + err = startHttpServer(ctx, serverPort) + if err != nil { + return err + } if enablePlaceholder { ports := strings.Split(serverPort, ":") if len(ports) < 2 { ports = append([]string{}, "", "8080") } logrus.Infof("waiting for placeholder, you can use `curl -X POST -d '{\"\":\"0x1234567890\"}' http://127.0.0.1:%s` to set placeholder", ports[1]) - configPath, err = waitForPlaceholder(context.Background(), configPath, serverPort) + configPath, err = waitForPlaceholder(context.Background(), configPath) if err != nil { return err } diff --git a/cmd/main.go b/cmd/main.go index 40ae8f3..6bfec12 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,12 +1,17 @@ package main import ( + "bytes" "context" + "encoding/json" "flag" "fmt" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" + "io" + "net/http" + "net/url" "omni-balance/internal/daemons" _ "omni-balance/internal/daemons/cross_chain" _ "omni-balance/internal/daemons/monitor" @@ -80,10 +85,9 @@ func Action(cli *cli.Context) error { logrus.SetFormatter(&logrus.JSONFormatter{}) } - if err := notice.Init(notice.Type(config.Notice.Type), config.Notice.Config); err != nil { + if err := notice.Init(notice.Type(config.Notice.Type), config.Notice.Config, config.Notice.Interval); err != nil { logrus.Warnf("init notice error: %v", err) } - notice.SetMsgInterval(config.Notice.Interval) if err := db.InitDb(*config); err != nil { return errors.Wrap(err, "init db") @@ -114,6 +118,47 @@ func main() { app.Name = "omni-balance" app.Action = Action app.Commands = []*cli.Command{ + { + Name: "del_order", + Usage: "delete order by id", + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "id", + Usage: "order id", + }, + &cli.StringFlag{ + Name: "server", + Usage: "server host, example: http://127.0.0.1:8080", + Value: "http://127.0.0.1:8080", + }, + }, + Action: func(c *cli.Context) error { + u, err := url.Parse(c.String("server")) + if err != nil { + return errors.Wrap(err, "parse server url") + } + u.RawPath = "/remove_order" + u.Path = u.RawPath + var body = bytes.NewBuffer(nil) + err = json.NewEncoder(body).Encode(map[string]interface{}{ + "id": c.Int("id"), + }) + if err != nil { + return errors.Wrap(err, "encode body") + } + resp, err := http.Post(u.String(), "application/json", body) + if err != nil { + return errors.Wrap(err, "post") + } + defer resp.Body.Close() + data, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return errors.Errorf("http status code: %d, body is: %s", resp.StatusCode, data) + } + logrus.Infof("delete order #%d success", c.Int64("id")) + return nil + }, + }, { Name: "version", Usage: "show version", diff --git a/internal/daemons/init.go b/internal/daemons/init.go index 607ca9c..cb1104b 100644 --- a/internal/daemons/init.go +++ b/internal/daemons/init.go @@ -103,17 +103,17 @@ func runForever(ctx context.Context, conf configs.Config, task Task) { func Run(ctx context.Context, conf configs.Config) error { for index := range tasks { if tasks[index].RunOnStart { - logrus.Infof("task %s run on start, wait for the task finished", tasks[index].Name) + logrus.Debugf("task %s run on start, wait for the task finished", tasks[index].Name) if err := tasks[index].TaskFunc(ctx, conf); err != nil { logrus.Errorf("task %s failed, err: %v", tasks[index].Name, err) continue } - logrus.Infof("task %s run on start finished", tasks[index].Name) + logrus.Debugf("task %s run on start finished", tasks[index].Name) continue } } for index := range tasks { - logrus.Infof("task %s run in background", tasks[index].Name) + logrus.Debugf("task %s run in background", tasks[index].Name) go runForever(ctx, conf, tasks[index]) } return nil diff --git a/internal/daemons/rebalance/rebalance.go b/internal/daemons/rebalance/rebalance.go index b18ff2c..b81fd1e 100644 --- a/internal/daemons/rebalance/rebalance.go +++ b/internal/daemons/rebalance/rebalance.go @@ -46,8 +46,29 @@ func Run(ctx context.Context, conf configs.Config) error { go func(order *models.Order) { defer w.Done() log := order.GetLogs() - utils.SetLogToCtx(ctx, log) - if err := reBalance(ctx, order, conf); err != nil { + subCtx, cancel := context.WithCancel(utils.SetLogToCtx(ctx, log)) + defer cancel() + + go func() { + defer cancel() + var t = time.NewTicker(time.Second * 5) + defer t.Stop() + + for { + select { + case <-subCtx.Done(): + return + case <-t.C: + var count int64 + _ = db.DB().Model(&models.Order{}).Where("id = ?", order.ID).Count(&count) + if count == 0 { + log.Infof("order #%d not found, exit this order rebalance", order.ID) + return + } + } + } + }() + if err := reBalance(subCtx, order, conf); err != nil { log.Errorf("reBalance order #%d error: %s", order.ID, err) return } @@ -149,7 +170,8 @@ func reBalance(ctx context.Context, order *models.Order, conf configs.Config) er return errors.Wrap(err, "save provider error") } - log.Infof("start reBalance %s on %s use %s provider", order.TokenOutName, order.TargetChainName, providerObj.Name()) + log.Infof("start reBalance #%d %s on %s use %s provider", order.ID, order.TokenOutName, + order.TargetChainName, providerObj.Name()) result, err := providerObj.Swap(ctx, args) if err != nil { return errors.Wrapf(err, "reBalance %s on %s error", order.TokenOutName, providerObj.Name()) @@ -321,12 +343,11 @@ func getReBalanceProvider(ctx context.Context, order models.Order, conf configs. } func providerSupportsOrder(ctx context.Context, p provider.Provider, order models.Order, conf configs.Config, log *logrus.Entry) (provider.TokenInCosts, bool) { - wallet := conf.GetWallet(order.Wallet) tokenInCosts, err := p.GetCost(ctx, provider.SwapParams{ SourceToken: order.TokenInName, Sender: conf.GetWallet(order.Wallet), TargetToken: order.TokenOutName, - Receiver: wallet.GetAddress().Hex(), + Receiver: order.Wallet, TargetChain: order.TargetChainName, Amount: order.Amount, }) diff --git a/internal/daemons/utils.go b/internal/daemons/utils.go index 5823837..b2fad49 100644 --- a/internal/daemons/utils.go +++ b/internal/daemons/utils.go @@ -13,7 +13,7 @@ func CreateSwapParams(order models.Order, orderProcess models.OrderProcess, log return provider.SwapParams{ SourceChain: order.CurrentChainName, Sender: wallet, - Receiver: wallet.GetAddress().Hex(), + Receiver: order.Wallet, TargetChain: order.TargetChainName, SourceToken: order.TokenOutName, TargetToken: order.TokenOutName, diff --git a/internal/models/order.go b/internal/models/order.go index de38dd9..7fc3e44 100644 --- a/internal/models/order.go +++ b/internal/models/order.go @@ -6,6 +6,7 @@ import ( "github.com/shopspring/decimal" "github.com/sirupsen/logrus" "gorm.io/gorm" + "omni-balance/utils" "omni-balance/utils/configs" "time" ) @@ -13,10 +14,10 @@ import ( type OrderStatus string const ( - OrderStatusWait OrderStatus = "wait" - OrderStatusProcessing OrderStatus = "processing" - OrderStatusSuccess OrderStatus = "success" - OrderStatusFail OrderStatus = "fail" + OrderStatusWait OrderStatus = "wait" + //OrderStatusProcessing OrderStatus = "processing" + OrderStatusSuccess OrderStatus = "success" + //OrderStatusFail OrderStatus = "fail" OrderStatusWaitTransferFromOperator OrderStatus = "wait_transfer_from_operator" OrderStatusWaitCrossChain OrderStatus = "wait_cross_chain" OrderStatusUnknown OrderStatus = "unknown" @@ -95,8 +96,7 @@ func GetLastOrderProcess(ctx context.Context, db *gorm.DB, orderId uint) OrderPr } func (o *Order) GetLogs() *logrus.Entry { - data, _ := json.Marshal(o) - var fields logrus.Fields - _ = json.Unmarshal(data, &fields) - return logrus.WithFields(fields) + return logrus.WithFields(logrus.Fields{ + "order": utils.ToMap(o), + }) } diff --git a/utils/chains/util.go b/utils/chains/util.go index 7e8f2f0..5e6d103 100644 --- a/utils/chains/util.go +++ b/utils/chains/util.go @@ -123,7 +123,7 @@ func WaitForTx(ctx context.Context, client simulated.Client, txHash common.Hash) return err } if errors.Is(err, ethereum.NotFound) { - log.Debugf("tx not found, txHash: %s, try again later", txHash.Hex()) + log.Infof("tx not found, txHash: %s, try again later", txHash.Hex()) continue } if err != nil { @@ -133,7 +133,7 @@ func WaitForTx(ctx context.Context, client simulated.Client, txHash common.Hash) if tx.Status == 0 { return errors.New("tx failed") } - log.Debugf("tx success, txHash: %s", txHash.Hex()) + log.Infof("tx success, txHash: %s", txHash.Hex()) return nil } } diff --git a/utils/configs/config.go b/utils/configs/config.go index cac6674..95ee890 100644 --- a/utils/configs/config.go +++ b/utils/configs/config.go @@ -58,12 +58,13 @@ type Config struct { TaskInterval map[string]time.Duration `json:"task_interval" yaml:"task_interval"` - // 通知渠道, 当成功rebalance时, 发送通知 - Notice struct { - Type string `json:"type" yaml:"type" comment:"Notice type, support: slack"` - Config map[string]interface{} `json:"config" yaml:"config" comment:"It depends on the notification type, slack needs ['webhook','channel']"` - Interval time.Duration `json:"interval" yaml:"interval" comment:"Same message send interval, minimum interval must be greater than or equal to 1 hour, default 1h"` - } `json:"notice" yaml:"notice" comment:"Notice config. When rebalance success, send notice"` + Notice Notice `json:"notice" yaml:"notice" comment:"Notice config. When rebalance success, send notice"` +} + +type Notice struct { + Type string `json:"type" yaml:"type" comment:"Notice type, support: slack"` + Config map[string]interface{} `json:"config" yaml:"config" comment:"It depends on the notification type, slack needs ['webhook','channel']"` + Interval time.Duration `json:"interval" yaml:"interval" comment:"Same message send interval, minimum interval must be greater than or equal to 1 hour, default 1h"` } type Chain struct { diff --git a/utils/notice/notice.go b/utils/notice/notice.go index a31bea1..1f16e1b 100644 --- a/utils/notice/notice.go +++ b/utils/notice/notice.go @@ -37,6 +37,7 @@ type Notice interface { func SetMsgInterval(interval time.Duration) { if interval.Seconds() < time.Hour.Seconds() { + logrus.Warnf("msg interval %s is too short, set to 1 hour", interval) msgInterval = time.Hour return } @@ -47,7 +48,7 @@ func WithFields(ctx context.Context, fields Fields) context.Context { return context.WithValue(ctx, constant.NoticeFieldsKeyInCtx, fields) } -func Init(noticeType Type, conf map[string]interface{}) error { +func Init(noticeType Type, conf map[string]interface{}, interval time.Duration) error { if notice != nil { return nil } @@ -64,6 +65,7 @@ func Init(noticeType Type, conf map[string]interface{}) error { return errors.Errorf("notice type %s not support", noticeType) } } + SetMsgInterval(interval) return nil } diff --git a/utils/provider/bridge/darwinia/darwinia.go b/utils/provider/bridge/darwinia/darwinia.go index a1505e6..b5c3b39 100644 --- a/utils/provider/bridge/darwinia/darwinia.go +++ b/utils/provider/bridge/darwinia/darwinia.go @@ -216,6 +216,7 @@ func (b *Bridge) Swap(ctx context.Context, args provider.SwapParams) (result pro recordFn(provider.SwapHistory{Actions: sourceChainSendingAction, Status: string(provider.TxStatusPending), CurrentChain: args.SourceChain}) ctx = provider.WithNotify(ctx, provider.WithNotifyParams{ + Receiver: common.HexToAddress(args.Receiver), TokenIn: args.SourceToken, TokenOut: args.TargetToken, TokenInChain: args.SourceChain, diff --git a/utils/provider/bridge/helix/helix.go b/utils/provider/bridge/helix/helix.go index 5bd15b8..048ae5a 100644 --- a/utils/provider/bridge/helix/helix.go +++ b/utils/provider/bridge/helix/helix.go @@ -113,6 +113,7 @@ func (b *Bridge) Swap(ctx context.Context, args provider.SwapParams) (result pro tx.Gas = 406775 ctx = provider.WithNotify(ctx, provider.WithNotifyParams{ + Receiver: common.HexToAddress(args.Receiver), TokenIn: args.SourceToken, TokenOut: args.TargetToken, TokenInChain: args.SourceChain, diff --git a/utils/provider/bridge/okx/okx.go b/utils/provider/bridge/okx/okx.go index e810fee..26dac94 100644 --- a/utils/provider/bridge/okx/okx.go +++ b/utils/provider/bridge/okx/okx.go @@ -194,6 +194,7 @@ func (o *OKX) Swap(ctx context.Context, args provider.SwapParams) (provider.Swap amount := args.Amount.Copy() args.Amount = tokenInAmount ctx = provider.WithNotify(ctx, provider.WithNotifyParams{ + Receiver: common.HexToAddress(args.Receiver), TokenIn: tokenIn.Name, TokenOut: tokenOut.Name, TokenInChain: args.SourceChain, @@ -232,7 +233,6 @@ func (o *OKX) Swap(ctx context.Context, args provider.SwapParams) (provider.Swap sh = sh.SetActions(SourceChainSendingAction) args.RecordFn(sh.SetStatus(provider.TxStatusPending).Out()) log.Debug("sending tx on chain") - //return provider.SwapResult{}, nil txHash, err := args.Sender.SendTransaction(ctx, &types.LegacyTx{ To: &buildTx.Tx.To, Value: buildTx.Tx.Value.BigInt(), diff --git a/utils/provider/bridge/routernitro/routernitro.go b/utils/provider/bridge/routernitro/routernitro.go index 0ede9c6..0611335 100644 --- a/utils/provider/bridge/routernitro/routernitro.go +++ b/utils/provider/bridge/routernitro/routernitro.go @@ -186,6 +186,7 @@ func (r Routernitro) Swap(ctx context.Context, args provider.SwapParams) (provid amount := args.Amount.Copy() args.Amount = tokenInAmount ctx = provider.WithNotify(ctx, provider.WithNotifyParams{ + Receiver: common.HexToAddress(args.Receiver), TokenIn: tokenIn.Name, TokenOut: tokenOut.Name, TokenInChain: args.SourceChain, @@ -240,7 +241,6 @@ func (r Routernitro) Swap(ctx context.Context, args provider.SwapParams) (provid sh = sh.SetActions(SourceChainSendingAction) args.RecordFn(sh.SetStatus(provider.TxStatusPending).Out()) log.Debug("sending tx on chain") - //return provider.SwapResult{}, nil txHash, err := args.Sender.SendTransaction(ctx, &types.LegacyTx{ To: &buildTx.Txn.To, Value: value.BigInt(), diff --git a/utils/provider/dex/uniswap/uniswap.go b/utils/provider/dex/uniswap/uniswap.go index fb76648..f417084 100644 --- a/utils/provider/dex/uniswap/uniswap.go +++ b/utils/provider/dex/uniswap/uniswap.go @@ -138,7 +138,7 @@ func (u *Uniswap) Swap(ctx context.Context, args provider.SwapParams) (result pr ProviderName: u.Name(), TokenInAmount: args.Amount, TokenOutAmount: args.Amount, - TransactionType: provider.SwapTransactionAction, + TransactionType: provider.ApproveTransactionAction, }) if err := chains.TokenApprove(ctx, chains.TokenApproveParams{ ChainId: int64(chain.Id), diff --git a/utils/provider/utils.go b/utils/provider/utils.go index 006f40e..1a14c1a 100644 --- a/utils/provider/utils.go +++ b/utils/provider/utils.go @@ -11,6 +11,7 @@ import ( "omni-balance/utils" "omni-balance/utils/chains" "omni-balance/utils/configs" + "omni-balance/utils/constant" "omni-balance/utils/notice" ) @@ -68,6 +69,17 @@ func Transfer(ctx context.Context, conf configs.Config, args SwapParams, client actionNumber = 0 args.LastHistory.Status = "" } + ctx = WithNotify(ctx, WithNotifyParams{ + Receiver: common.HexToAddress(args.Receiver), + TokenIn: args.SourceToken, + TokenOut: args.TargetToken, + TokenInChain: args.SourceChain, + TokenOutChain: args.TargetChain, + ProviderName: "transfer", + TokenInAmount: args.Amount, + TokenOutAmount: args.Amount, + TransactionType: TransferTransactionAction, + }) var txHash = last.Tx if actionNumber < 1 && args.LastHistory.Status != TxStatusSuccess.String() { @@ -170,6 +182,7 @@ func GetTokenCrossChainProviders(ctx context.Context, args GetTokenCrossChainPro } type WithNotifyParams struct { + Receiver common.Address TokenIn, TokenOut, TokenInChain, TokenOutChain, ProviderName string TokenInAmount, TokenOutAmount decimal.Decimal TransactionType TransactionType @@ -199,5 +212,9 @@ func WithNotify(ctx context.Context, args WithNotifyParams) context.Context { fields["tokenOut"] = fmt.Sprintf("%s on %s", fields["tokenOut"], args.TokenOutChain) } } + + if args.Receiver.Cmp(constant.ZeroAddress) != 0 { + fields["receiver"] = args.Receiver.Hex() + } return notice.WithFields(ctx, fields) } diff --git a/utils/wallets/safe/safe.go b/utils/wallets/safe/safe.go index df0b19a..cf489ee 100644 --- a/utils/wallets/safe/safe.go +++ b/utils/wallets/safe/safe.go @@ -118,7 +118,7 @@ func (s *Safe) SendTransaction(ctx context.Context, tx *types.LegacyTx, client s func (s *Safe) WaitTransaction(ctx context.Context, txHash common.Hash, _ simulated.Client) error { var ( - log = utils.GetLogFromCtx(ctx) + log = utils.GetLogFromCtx(ctx).WithFields(utils.ToMap(s)) t = time.NewTicker(time.Second * 2) count = 0 ) @@ -138,10 +138,11 @@ func (s *Safe) WaitTransaction(ctx context.Context, txHash common.Hash, _ simula } if len(tx.Confirmations) < tx.ConfirmationsRequired { count = 0 - log.Debugf("transaction %s confirmations: %d, required: %d, try to sending notice", - txHash, len(tx.Confirmations), tx.ConfirmationsRequired) + log.Infof("%s transaction %s confirmations: %d, required: %d,", + tx.Safe, txHash, len(tx.Confirmations), tx.ConfirmationsRequired) if err := notice.Send(ctx, - fmt.Sprintf("wait %s safeHash %s confirmations and execute.", constant.GetChainName(s.GetChainIdByCtx(ctx)), txHash), + fmt.Sprintf("wait %s safeHash %s confirmations and execute.", + constant.GetChainName(s.GetChainIdByCtx(ctx)), txHash), fmt.Sprintf("Please go to %s %s safe address to confirm and execute #%d transaction.", constant.GetChainName(s.GetChainIdByCtx(ctx)), tx.Safe, tx.Nonce), logrus.WarnLevel,