diff --git a/pkg/api/client/client.go b/pkg/api/client/client.go index b564d5dd..8d088f72 100644 --- a/pkg/api/client/client.go +++ b/pkg/api/client/client.go @@ -134,3 +134,82 @@ func LocalClient( return string(b), nil } + +func WaitNotification( + cfg *config.Instance, + id string, +) (string, error) { + u := url.URL{ + Scheme: "ws", + Host: "localhost:" + strconv.Itoa(cfg.ApiPort()), + Path: "/api/v1.0", + } + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return "", err + } + defer func(c *websocket.Conn) { + err := c.Close() + if err != nil { + log.Warn().Err(err).Msg("error closing websocket") + } + }(c) + + done := make(chan struct{}) + var resp *models.RequestObject + + go func() { + defer close(done) + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Error().Err(err).Msg("error reading message") + return + } + + var m models.RequestObject + err = json.Unmarshal(message, &m) + if err != nil { + continue + } + + if m.JsonRpc != "2.0" { + log.Error().Msg("invalid jsonrpc version") + continue + } + + if m.Id != nil { + continue + } + + if m.Method != id { + continue + } + + resp = &m + + return + } + }() + + timer := time.NewTimer(api.RequestTimeout) + select { + case <-done: + break + case <-timer.C: + return "", ErrRequestTimeout + } + + if resp == nil { + return "", ErrRequestTimeout + } + + var b []byte + b, err = json.Marshal(resp.Params) + if err != nil { + return "", err + } + + return string(b), nil +} diff --git a/pkg/api/methods/mappings.go b/pkg/api/methods/mappings.go index 65b41134..8bfe1311 100644 --- a/pkg/api/methods/mappings.go +++ b/pkg/api/methods/mappings.go @@ -5,6 +5,8 @@ import ( "errors" "github.com/ZaparooProject/zaparoo-core/pkg/api/models" "github.com/ZaparooProject/zaparoo-core/pkg/api/models/requests" + "github.com/ZaparooProject/zaparoo-core/pkg/platforms" + "path/filepath" "regexp" "strconv" "time" @@ -215,3 +217,16 @@ func HandleUpdateMapping(env requests.RequestEnv) (any, error) { return nil, nil } + +func HandleReloadMappings(env requests.RequestEnv) (any, error) { + log.Info().Msg("received reload mappings request") + + mapDir := filepath.Join(env.Platform.DataDir(), platforms.MappingsDir) + err := env.Config.LoadMappings(mapDir) + if err != nil { + log.Error().Err(err).Msg("error loading mappings") + return nil, errors.New("error loading mappings") + } + + return nil, nil +} diff --git a/pkg/api/models/models.go b/pkg/api/models/models.go index 11f3947b..c3d8e2da 100644 --- a/pkg/api/models/models.go +++ b/pkg/api/models/models.go @@ -25,6 +25,7 @@ const ( MethodMappingsNew = "mappings.new" MethodMappingsDelete = "mappings.delete" MethodMappingsUpdate = "mappings.update" + MethodMappingsReload = "mappings.reload" MethodReadersWrite = "readers.write" MethodStatus = "status" MethodVersion = "version" diff --git a/pkg/api/server.go b/pkg/api/server.go index 407e0894..d7b4ccc0 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -52,6 +52,7 @@ var methodMap = map[string]func(requests.RequestEnv) (any, error){ models.MethodMappingsNew: methods.HandleAddMapping, models.MethodMappingsDelete: methods.HandleDeleteMapping, models.MethodMappingsUpdate: methods.HandleUpdateMapping, + models.MethodMappingsReload: methods.HandleReloadMappings, // readers models.MethodReadersWrite: methods.HandleReaderWrite, // utils diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 467647eb..38c9db7c 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -15,11 +15,15 @@ import ( "github.com/rs/zerolog/log" "io" "os" + "os/signal" "strings" + "syscall" ) type Flags struct { Write *string + Read *bool + Run *string Launch *string Api *string Clients *bool @@ -35,38 +39,48 @@ func SetupFlags() *Flags { Write: flag.String( "write", "", - "write text to tag using connected reader", + "write value to next scanned token", + ), + Read: flag.Bool( + "read", + false, + "print next scanned token without running", + ), + Run: flag.String( + "run", + "", + "run value directly as ZapScript", ), Launch: flag.String( "launch", "", - "launch text as if it were a scanned token", + "alias of run (DEPRECATED)", ), Api: flag.String( "api", "", "send method and params to API and print response", ), - Clients: flag.Bool( - "clients", - false, - "list all registered API clients and secrets", - ), - NewClient: flag.String( - "new-client", - "", - "register new API client with given display name", - ), - DeleteClient: flag.String( - "delete-client", - "", - "revoke access to API for given client ID", - ), - Qr: flag.Bool( - "qr", - false, - "output a connection QR code along with client details", - ), + //Clients: flag.Bool( + // "clients", + // false, + // "list all registered API clients and secrets", + //), + //NewClient: flag.String( + // "new-client", + // "", + // "register new API client with given display name", + //), + //DeleteClient: flag.String( + // "delete-client", + // "", + // "revoke access to API for given client ID", + //), + //Qr: flag.Bool( + // "qr", + // false, + // "output a connection QR code along with client details", + //), Version: flag.Bool( "version", false, @@ -112,9 +126,55 @@ func (f *Flags) Post(cfg *config.Instance) { } else { os.Exit(0) } - } else if *f.Launch != "" { + } else if *f.Read { + enableRun := func() { + _, err := client.LocalClient( + cfg, + models.MethodSettingsUpdate, + "{\"launchingActive\":true}", + ) + if err != nil { + log.Error().Err(err).Msg("error re-enabling run") + _, _ = fmt.Fprintf(os.Stderr, "Error re-enabling run: %v\n", err) + os.Exit(1) + } + } + + _, err := client.LocalClient( + cfg, + models.MethodSettingsUpdate, + "{\"launchingActive\":false}", + ) + if err != nil { + log.Error().Err(err).Msg("error disabling run") + _, _ = fmt.Fprintf(os.Stderr, "Error disabling run: %v\n", err) + os.Exit(1) + } + + // cleanup after ctrl-c + sigs := make(chan os.Signal, 1) + defer close(sigs) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + enableRun() + os.Exit(0) + }() + + resp, err := client.WaitNotification(cfg, models.TokensActive) + if err != nil { + log.Error().Err(err).Msg("error waiting for notification") + _, _ = fmt.Fprintf(os.Stderr, "Error waiting for notification: %v\n", err) + enableRun() + os.Exit(1) + } + + enableRun() + fmt.Println(resp) + os.Exit(0) + } else if *f.Run != "" || *f.Launch != "" { data, err := json.Marshal(&models.LaunchParams{ - Text: f.Launch, + Text: f.Run, }) if err != nil { _, _ = fmt.Fprintf(os.Stderr, "Error encoding params: %v\n", err) diff --git a/pkg/config/config.go b/pkg/config/config.go index 15c4d68e..e2ad0102 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,6 +30,7 @@ type Values struct { Launchers Launchers `toml:"launchers,omitempty"` ZapScript ZapScript `toml:"zapscript,omitempty"` Service Service `toml:"service,omitempty"` + Mappings Mappings `toml:"mappings,omitempty"` } type Audio struct { @@ -77,6 +78,16 @@ type Service struct { AllowLaunch []string `toml:"allow_launch,omitempty,multiline"` } +type MappingsEntry struct { + TokenKey string `toml:"token_key,omitempty"` + MatchPattern string `toml:"match_pattern"` + ZapScript string `toml:"zapscript"` +} + +type Mappings struct { + Entry []MappingsEntry `toml:"entry,omitempty"` +} + var BaseDefaults = Values{ ConfigSchema: SchemaVersion, Audio: Audio{ @@ -346,3 +357,60 @@ func (c *Instance) IsShellCmdAllowed(cmd string) bool { } return false } + +func (c *Instance) LoadMappings(mappingsDir string) error { + c.mu.Lock() + defer c.mu.Unlock() + + _, err := os.Stat(mappingsDir) + if err != nil { + return err + } + + mapFiles, err := os.ReadDir(mappingsDir) + if err != nil { + return err + } + + filesCounts := 0 + mappingsCount := 0 + + for _, mapFile := range mapFiles { + if mapFile.IsDir() { + continue + } + + if filepath.Ext(mapFile.Name()) != ".toml" { + continue + } + + mapPath := filepath.Join(mappingsDir, mapFile.Name()) + log.Debug().Msgf("loading mapping file: %s", mapPath) + + data, err := os.ReadFile(mapPath) + if err != nil { + return err + } + + var newVals Values + err = toml.Unmarshal(data, &newVals) + if err != nil { + return err + } + + c.vals.Mappings.Entry = append(c.vals.Mappings.Entry, newVals.Mappings.Entry...) + + filesCounts++ + mappingsCount += len(newVals.Mappings.Entry) + } + + log.Info().Msgf("loaded %d mapping files, %d mappings", filesCounts, mappingsCount) + + return nil +} + +func (c *Instance) Mappings() []MappingsEntry { + c.mu.RLock() + defer c.mu.RUnlock() + return c.vals.Mappings.Entry +} diff --git a/pkg/service/mappings.go b/pkg/service/mappings.go index d9824597..3550dc29 100644 --- a/pkg/service/mappings.go +++ b/pkg/service/mappings.go @@ -22,6 +22,7 @@ along with Zaparoo Core. If not, see . package service import ( + "github.com/ZaparooProject/zaparoo-core/pkg/config" "github.com/ZaparooProject/zaparoo-core/pkg/service/tokens" "regexp" "strings" @@ -87,28 +88,74 @@ func checkMappingData(m database.Mapping, t tokens.Token) bool { return false } -func getMapping(db *database.Database, pl platforms.Platform, token tokens.Token) (string, bool) { +func isCfgRegex(s string) bool { + return len(s) > 2 && s[0] == '/' && s[len(s)-1] == '/' +} + +func mappingsFromConfig(cfg *config.Instance) []database.Mapping { + var mappings []database.Mapping + cfgMappings := cfg.Mappings() + + for _, m := range cfgMappings { + var dbm database.Mapping + dbm.Enabled = true + dbm.Override = m.ZapScript + + if m.TokenKey == "data" { + dbm.Type = database.MappingTypeData + } else if m.TokenKey == "value" { + dbm.Type = database.MappingTypeText + } else { + dbm.Type = database.MappingTypeUID + } + + if isCfgRegex(m.MatchPattern) { + dbm.Match = database.MatchTypeRegex + dbm.Pattern = m.MatchPattern[1 : len(m.MatchPattern)-1] + } else if strings.Contains(m.MatchPattern, "*") { + // TODO: this behaviour doesn't actually match "partial" + // the old behaviour will need to be migrated to this one + dbm.Match = database.MatchTypePartial + dbm.Pattern = strings.ReplaceAll(m.MatchPattern, "*", "") + } else { + dbm.Match = database.MatchTypeExact + dbm.Pattern = m.MatchPattern + } + + mappings = append(mappings, dbm) + } + + return mappings +} + +func getMapping(cfg *config.Instance, db *database.Database, pl platforms.Platform, token tokens.Token) (string, bool) { + // TODO: need a way to identify the source of a match so it can be + // reported and debugged by the user if there's issues + // check db mappings ms, err := db.GetEnabledMappings() if err != nil { log.Error().Err(err).Msgf("error getting db mappings") } + // load config mappings after + ms = append(ms, mappingsFromConfig(cfg)...) + for _, m := range ms { switch { case m.Type == database.MappingTypeUID: if checkMappingUid(m, token) { - log.Info().Msg("launching with db uid match override") + log.Info().Msg("launching with db/cfg uid match override") return m.Override, true } case m.Type == database.MappingTypeText: if checkMappingText(m, token) { - log.Info().Msg("launching with db text match override") + log.Info().Msg("launching with db/cfg text match override") return m.Override, true } case m.Type == database.MappingTypeData: if checkMappingData(m, token) { - log.Info().Msg("launching with db data match override") + log.Info().Msg("launching with db/cfg data match override") return m.Override, true } } diff --git a/pkg/service/service.go b/pkg/service/service.go index 8d8478c4..cca82ade 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -27,6 +27,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/pkg/service/playlists" "github.com/ZaparooProject/zaparoo-core/pkg/service/tokens" "os" + "path/filepath" "strings" "time" @@ -58,7 +59,7 @@ func launchToken( ) error { text := token.Text - mappingText, mapped := getMapping(db, platform, token) + mappingText, mapped := getMapping(cfg, db, platform, token) if mapped { log.Info().Msgf("found mapping: %s", mappingText) text = mappingText @@ -223,7 +224,11 @@ func Start( pl platforms.Platform, cfg *config.Instance, ) (func() error, error) { - dirs := []string{pl.DataDir(), pl.TempDir()} + dirs := []string{ + pl.DataDir(), + pl.TempDir(), + filepath.Join(pl.DataDir(), platforms.MappingsDir), + } for _, dir := range dirs { err := os.MkdirAll(dir, 0755) if err != nil { @@ -245,10 +250,17 @@ func Start( return nil, err } - log.Info().Msg("opening database") + log.Info().Msg("opening user database") db, err := database.Open(pl) if err != nil { - log.Error().Err(err).Msgf("error opening database") + log.Error().Err(err).Msgf("error opening user database") + return nil, err + } + + log.Info().Msg("loading mapping files") + err = cfg.LoadMappings(filepath.Join(pl.DataDir(), platforms.MappingsDir)) + if err != nil { + log.Error().Err(err).Msgf("error loading mapping files") return nil, err } @@ -256,27 +268,27 @@ func Start( go api.Start(pl, cfg, st, itq, db, ns) if !pl.LaunchingEnabled() { - log.Warn().Msg("launching disabled") + log.Warn().Msg("launching disabled by user") st.DisableLauncher() } log.Info().Msg("starting reader manager") go readerManager(pl, cfg, st, itq, lsq) - log.Info().Msg("starting token queue manager") + log.Info().Msg("starting input token queue manager") go processTokenQueue(pl, cfg, st, itq, db, lsq, plq) log.Info().Msg("running platform post start") err = pl.StartPost(cfg, st.Notifications) if err != nil { - log.Error().Err(err).Msg("platform start pre error") + log.Error().Err(err).Msg("platform post start error") return nil, err } return func() error { err = pl.Stop() if err != nil { - log.Warn().Msgf("error stopping pl: %s", err) + log.Warn().Msgf("error stopping platform: %s", err) } st.StopService() close(plq) diff --git a/pkg/service/state/state.go b/pkg/service/state/state.go index 14eb6422..4eb6393f 100644 --- a/pkg/service/state/state.go +++ b/pkg/service/state/state.go @@ -86,28 +86,14 @@ func (s *State) ShouldStopService() bool { func (s *State) DisableLauncher() { s.mu.Lock() + defer s.mu.Unlock() s.disableLauncher = true - if err := s.platform.SetLaunching(false); err != nil { - log.Error().Msgf("cannot create disable launch file: %s", err) - } - s.Notifications <- models.Notification{ - Method: models.TokensLaunching, - Params: false, - } - s.mu.Unlock() } func (s *State) EnableLauncher() { s.mu.Lock() + defer s.mu.Unlock() s.disableLauncher = false - if err := s.platform.SetLaunching(true); err != nil { - log.Error().Msgf("cannot remove disable launch file: %s", err) - } - s.Notifications <- models.Notification{ - Method: models.TokensLaunching, - Params: true, - } - s.mu.Unlock() } func (s *State) IsLauncherDisabled() bool {