diff --git a/.gitignore b/.gitignore index 9dc6f71..3611cd0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ redactedhook/ __debug_bin .RedactedHook .DS_Store + +## tests +internal/config/testconfig_updated.toml diff --git a/README.md b/README.md index 4211cd5..d283c7a 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,29 @@ # RedactedHook -RedactedHook is a webhook companion service for [autobrr](https://github.com/autobrr/autobrr) designed to check the names of uploaders, your ratio, and record labels associated with torrents on **Redacted** and **Orpheus**. It provides a simple and efficient way to validate if uploaders are blacklisted or whitelisted, to stop racing in case your ratio falls below a certain point, and to verify if a torrent's record label matches against a specified list. +RedactedHook is a webhook companion service for [autobrr](https://github.com/autobrr/autobrr) designed to check the names of uploaders, your ratio, torrent size and record labels associated with torrents on **Redacted** and **Orpheus**. It provides a simple and efficient way to validate if uploaders are blacklisted or whitelisted, to stop racing in case your ratio falls below a certain point, and to verify if a torrent's record label matches against a specified list. ## Table of Contents - [Features](#features) - [Getting Started](#getting-started) - - [Warning](#warning) - - [Installation](#installation) - - [Docker](#docker) - - [Docker Compose](#docker-compose) - - [Using precompiled binaries](#using-precompiled-binaries) - - [Building from source](#building-from-source) +- [Warning](#warning) +- [Installation](#installation) + - [Docker](#docker) + - [Docker Compose](#docker-compose) + - [Using precompiled binaries](#using-precompiled-binaries) + - [Building from source](#building-from-source) - [Usage](#usage) - - [Config](#config) - - [Authorization](#authorization) - - [Payload](#payload) + - [Commands](#commands) +- [Config](#config) +- [Authorization](#authorization) +- [Payload](#payload) ## Features - Verify if an uploader's name is on a provided whitelist or blacklist. - Check for record labels. Useful for grabbing torrents from a specific record label. - Check if a user's ratio meets a specified minimum value. -- Check the torrentSize (Useful for not hitting the API from both autobrr and redactedhook) +- Check the torrentSize (Useful for not hitting the API from both autobrr and redactedhook). - Easy to integrate with other applications via webhook. - Rate-limited to comply with tracker API request policies. - With a 5-minute data cache to reduce frequent API calls for the same data. @@ -31,22 +32,22 @@ It was made with [autobrr](https://github.com/autobrr/autobrr) in mind. ## Getting Started -### Warning +## Warning > \[!IMPORTANT] > > Remember that autobrr also checks the RED/OPS API if you have min/max sizes set. This will result in you hitting the API 2x. > So for your own good, **only** set size checks in RedactedHook. -### Installation +## Installation -#### Docker +### Docker ```bash docker pull ghcr.io/s0up4200/redactedhook:latest ``` -#### Docker Compose +### Docker Compose ```docker services: @@ -61,8 +62,9 @@ services: #cap_drop: # - ALL environment: - - REDACTEDHOOK__HOST=0.0.0.0 # binds to 127.0.0.1 by default - - REDACTEDHOOK__PORT=42135 # defaults to 42135 + #- REDACTEDHOOK__HOST=127.0.0.1 # Override the host from config.toml + #- REDACTEDHOOK__PORT=42135 # Override the port from config.toml + #- REDACTEDHOOK__API_TOKEN= # Override the api_token from config.toml - TZ=UTC ports: - "42135:42135" @@ -71,41 +73,37 @@ services: restart: unless-stopped ``` -#### Using precompiled binaries +### Using precompiled binaries Download the appropriate binary for your platform from the [releases](https://github.com/s0up4200/RedactedHook/releases/latest) page. -#### Building from source +### Building from source 1. Clone the repository: - ```bash - git clone https://github.com/s0up4200/RedactedHook.git - ``` +```bash +git clone https://github.com/s0up4200/RedactedHook.git +``` 2. Navigate to the project directory: - ```bash - cd RedactedHook - ``` +```bash +cd RedactedHook +``` 3. Build the project: - ```go - go build - ``` - - or - - ```shell - make build - ``` +```bash +go build +or +make build +``` 4. Run the compiled binary: - ```bash - ./bin/RedactedHook --config /path/to/config.toml # config flag not necessary if file is next to binary - ``` +```bash +./bin/RedactedHook --config /path/to/config.toml # config flag not necessary if file is next to binary +``` ## Usage @@ -120,17 +118,27 @@ Expected HTTP Status: 200 You can check ratio, uploader (whitelist and blacklist), minsize, maxsize, and record labels in a single request, or separately. -### Config +### Commands + +- `generate-apitoken`: Generate a new API token and print it. +- `create-config`: Create a default configuration file. +- `help`: Display this help message. -Most of `requestData` can be set in `config.toml` to reduce the payload from autobrr. +## Config -Config can be created with: `redactedhook create-config` +Most of requestData can be set in config.toml to reduce the payload from autobrr. + +### Example config.toml ```toml +[server] +host = "127.0.0.1" # Server host +port = 42135 # Server port + [authorization] api_token = "" # generate with "redactedhook generate-apitoken" # the api_token needs to be set as a header for the webhook to work -# eg. Header=X-API-Token asd987gsd98g7324kjh142kjh +# eg. Header=X-API-Token=asd987gsd98g7324kjh142kjh [indexer_keys] #red_apikey = "" # generate in user settings, needs torrent and user privileges @@ -164,7 +172,7 @@ maxage = 28 # Max age in days to keep a log file compress = false # Whether to compress old log files ``` -### Authorization +## Authorization API Token can be generated like this: `redactedhook generate-apitoken` @@ -182,9 +190,9 @@ curl -X POST \ http://127.0.0.1:42135/hook ``` -### Payload +## Payload -**The minimum required data to send with the webhook:** +The minimum required data to send with the webhook: ```json { @@ -193,26 +201,20 @@ curl -X POST \ } ``` -Everything else can be set in the `config.toml`, but you can set them in the webhook as well, if you want to filter by different things in different filters. - -`indexer` - `"{{ .Indexer | js }}"` this is the indexer that pushed the release within autobrr. - -`torrent_id` - `{{.TorrentID}}` this is the TorrentID of the pushed release within autobrr. - -`red_user_id` is the number in the URL when you visit your profile. - -`ops_user_id` is the number in the URL when you visit your profile. - -`red_apikey` is your Redacted API key. Needs user and torrents privileges. - -`ops_apikey` is your Orpheus API key. Needs user and torrents privileges. - -`record_labels` is a comma-separated list of record labels to check against. - -`minsize` is the minimum allowed size you want to grab. Eg. `100MB` +Everything else can be set in the config.toml, but you can set them in the webhook as well, if you want to filter by different things in different filters. -`maxsize` is the max allowed size you want to grab. Eg. `500MB` +- `indexer` - `"{{ .Indexer | js }}"` this is the indexer that pushed the release within autobrr. +- `torrent_id` - `{{.TorrentID}}` this is the TorrentID of the pushed release within autobrr. -`uploaders` is a comma-separated list of uploaders to check against. +### Additional Keys -`mode` is either blacklist or whitelist. If blacklist is used, the torrent will be stopped if the uploader is found in the list. If whitelist is used, the torrent will be stopped if the uploader is not found in the list. +- `red_user_id` is the number in the URL when you visit your profile. +- `ops_user_id` is the number in the URL when you visit your profile. +- `red_apikey` is your Redacted API key. Needs user and torrents privileges. +- `ops_apikey` is your Orpheus API key. Needs user and torrents privileges. +- `record_labels` is a comma-separated list of record labels to check against. +- `minsize` is the minimum allowed size you want to grab. Eg. 100MB +- `maxsize` is the max allowed size you want to grab. Eg. 500MB +- `uploaders` is a comma-separated list of uploaders to check against. +- `mode` is either blacklist or whitelist. If blacklist is used, the torrent will be stopped if the uploader is found in the list. If whitelist is used, the torrent will be stopped if the uploader is not found in the list. + ` diff --git a/cmd/redactedhook/main.go b/cmd/redactedhook/main.go index cc9853f..d904921 100644 --- a/cmd/redactedhook/main.go +++ b/cmd/redactedhook/main.go @@ -19,30 +19,39 @@ import ( ) var ( - version string - commit string - buildDate string + version = "dev" + commit = "none" + buildDate = "unknown" ) const ( - path = "/hook" - EnvServerAddress = "REDACTEDHOOK__HOST" - EnvServerPort = "REDACTEDHOOK__PORT" + path = "/hook" + tokenLength = 16 ) -func generateAPIToken(length int) string { - b := make([]byte, length) +func generateAPIToken() string { + b := make([]byte, tokenLength) if _, err := rand.Read(b); err != nil { log.Fatal().Err(err).Msg("Failed to generate API key") - return "" } apiKey := hex.EncodeToString(b) - // codeql-ignore-next-line: go/clear-text-logging-of-sensitive-information fmt.Fprintf(os.Stdout, "API Token: %v, copy and paste into your config.toml\n", apiKey) return apiKey } -func flagCommands() (string, bool) { +func printHelp() { + fmt.Println("Usage: redactedhook [options] [command]") + fmt.Println() + fmt.Println("Options:") + flag.PrintDefaults() + fmt.Println() + fmt.Println("Commands:") + fmt.Println(" generate-apitoken Generate a new API token and print it.") + fmt.Println(" create-config Create a default configuration file.") + fmt.Println(" help Display this help message.") +} + +func parseFlags() (string, bool) { var configPath string flag.StringVar(&configPath, "config", "config.toml", "Path to the configuration file") flag.Parse() @@ -50,71 +59,78 @@ func flagCommands() (string, bool) { if len(flag.Args()) > 0 { switch flag.Arg(0) { case "generate-apitoken": - return generateAPIToken(16), true + generateAPIToken() + return "", true case "create-config": - return config.CreateConfigFile(), true + config.CreateConfigFile() + return "", true + case "help": + printHelp() + return "", true default: - log.Fatal().Msgf("Unknown command: %s", flag.Arg(0)) + log.Fatal().Msgf("Unknown command: %s. Use 'redactedhook help' to see available commands.", flag.Arg(0)) } } return configPath, false } func getEnv(key, defaultValue string) string { - value := os.Getenv(key) - if value == "" { - return defaultValue + if value, exists := os.LookupEnv(key); exists { + return value } - return value + return defaultValue +} + +func initLogger() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02 15:04:05"}) } -func startHTTPServer(address, port string) { - server := &http.Server{Addr: address + ":" + port} +func startHTTPServer(address string) { + server := &http.Server{Addr: address} go func() { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatal().Err(err).Msg("Failed to start server") + log.Fatal().Err(err).Msg("HTTP server crashed") } }() - log.Info().Msgf("Starting server on %s", address+":"+port) + log.Info().Msgf("Starting server on %s", address) log.Info().Msgf("Version: %s, Commit: %s, Build Date: %s", version, commit, buildDate) - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - <-c + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + <-sig ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := server.Shutdown(ctx); err != nil { log.Error().Err(err).Msg("Server shutdown failed") - } else { - log.Info().Msg("Server gracefully stopped") } } func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02 15:04:05", NoColor: false}) + initLogger() - configPath, isCommandExecuted := flagCommands() + configPath, isCommandExecuted := parseFlags() if isCommandExecuted { return } config.InitConfig(configPath) - err := config.ValidateConfig() - if err != nil { + if err := config.ValidateConfig(); err != nil { log.Fatal().Err(err).Msg("Invalid configuration") - } else { - log.Debug().Msg("Configuration is valid.") } http.HandleFunc(path, api.WebhookHandler) - address := getEnv(EnvServerAddress, "127.0.0.1") - port := getEnv(EnvServerPort, "42135") + host := getEnv("REDACTEDHOOK__HOST", config.GetConfig().Server.Host) + port := getEnv("REDACTEDHOOK__PORT", fmt.Sprintf("%d", config.GetConfig().Server.Port)) + apiToken := getEnv("REDACTEDHOOK__API_TOKEN", config.GetConfig().Authorization.APIToken) + + config.GetConfig().Authorization.APIToken = apiToken + + address := fmt.Sprintf("%s:%s", host, port) - startHTTPServer(address, port) + startHTTPServer(address) } diff --git a/config.toml b/config.toml index a2c65b9..957ed19 100644 --- a/config.toml +++ b/config.toml @@ -1,7 +1,11 @@ +[server] +host = "127.0.0.1" # Server host +port = 42135 # Server port + [authorization] api_token = "" # generate with "redactedhook generate-apitoken" # the api_token needs to be set as a header for the webhook to work -# eg. Header: X-API-Token=asd987gsd98g7324kjh142kjh +# eg. Header: X-API-Token=aaa129cd1d66ed6fa567da2d07a5dd0e [indexer_keys] #red_apikey = "" # generate in user settings, needs torrent and user privileges diff --git a/docker-compose.yml b/docker-compose.yml index 5cb048e..744aa03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,8 +3,8 @@ services: container_name: redactedhook image: ghcr.io/s0up4200/redactedhook:latest #build: - #context: . - #dockerfile: Dockerfile + # context: . + # dockerfile: Dockerfile #runtime: runsc-ptrace #network_mode: bridge user: nobody @@ -14,8 +14,9 @@ services: cap_drop: - ALL environment: - - REDACTEDHOOK__HOST=0.0.0.0 # binds to 127.0.0.1 by default - - REDACTEDHOOK__PORT=42135 # defaults to 42135 + #- REDACTEDHOOK__HOST=127.0.0.1 # Override the host from config.toml + #- REDACTEDHOOK__PORT=42135 # Override the port from config.toml + #- REDACTEDHOOK__API_TOKEN= # Override the API token from config.toml - TZ=UTC ports: - 127.0.0.1:42135:42135 diff --git a/go.mod b/go.mod index a7c89c2..5dbdd39 100644 --- a/go.mod +++ b/go.mod @@ -8,16 +8,19 @@ require ( github.com/natefinch/lumberjack v2.0.0+incompatible github.com/rs/zerolog v1.29.0 github.com/spf13/viper v1.17.0 + github.com/stretchr/testify v1.8.4 golang.org/x/time v0.3.0 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 3e92b7d..2b0354d 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -13,6 +13,12 @@ func TestValidateRequestData(t *testing.T) { wantErr bool errMsg string }{ + { + name: "Empty request", + request: RequestData{}, + wantErr: true, + errMsg: "no indexer provided", + }, { name: "Valid request", request: RequestData{Indexer: "ops", TorrentID: 123, REDKey: "123456789012345678901234567890123456789012", OPSKey: "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012", MinRatio: 1.0, MinSize: 0, MaxSize: 10, Uploaders: "uploader1", RecordLabel: "label1", Mode: "blacklist"}, @@ -26,46 +32,58 @@ func TestValidateRequestData(t *testing.T) { errMsg: "invalid indexer: invalid", }, { - name: "Invalid torrent ID", - request: RequestData{Indexer: "ops", TorrentID: 1000000000}, - wantErr: true, - errMsg: "invalid torrent ID: 1000000000", + name: "Minimum valid torrent ID", + request: RequestData{Indexer: "ops", TorrentID: 1}, + wantErr: false, + errMsg: "", }, { - name: "REDKey too long", - request: RequestData{Indexer: "redacted", REDKey: "12345678901234567890212345678901234567890123"}, - wantErr: true, - errMsg: "REDKey is too long", + name: "Maximum valid torrent ID", + request: RequestData{Indexer: "ops", TorrentID: 999999999}, + wantErr: false, + errMsg: "", }, { - name: "OPSKey too long", - request: RequestData{Indexer: "ops", OPSKey: "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012109213091823098123091283"}, - wantErr: true, - errMsg: "OPSKey is too long", + name: "REDKey at maximum length", + request: RequestData{Indexer: "redacted", REDKey: "123456789012345678901234567890123456789012"}, + wantErr: false, + errMsg: "", }, { - name: "MinRatio out of range", - request: RequestData{Indexer: "ops", MinRatio: 1000}, - wantErr: true, - errMsg: "minRatio must be between 0 and 999.999", + name: "OPSKey at maximum length", + request: RequestData{Indexer: "ops", OPSKey: "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345"}, + wantErr: false, + errMsg: "", }, { - name: "MinSize greater than MaxSize", - request: RequestData{Indexer: "ops", MinSize: 11, MaxSize: 10}, - wantErr: true, - errMsg: "minsize cannot be greater than maxsize", + name: "MinRatio at lower boundary", + request: RequestData{Indexer: "ops", MinRatio: 0}, + wantErr: false, + errMsg: "", }, { - name: "Invalid RecordLabel", - request: RequestData{Indexer: "ops", RecordLabel: "label#1"}, - wantErr: true, - errMsg: "recordLabels field should only contain alphanumeric characters, spaces, and safe special characters", + name: "MinRatio at upper boundary", + request: RequestData{Indexer: "ops", MinRatio: 999.999}, + wantErr: false, + errMsg: "", + }, + { + name: "Valid RecordLabel with special characters", + request: RequestData{Indexer: "ops", RecordLabel: "label1 & label2 - label3"}, + wantErr: false, + errMsg: "", }, { - name: "Invalid Mode with Uploaders", - request: RequestData{Indexer: "ops", Uploaders: "uploader1", Mode: "invalid_mode"}, + name: "Empty mode with uploaders", + request: RequestData{Indexer: "ops", Uploaders: "uploader1"}, wantErr: true, - errMsg: "mode must be either 'whitelist' or 'blacklist', got 'invalid_mode'", + errMsg: "mode must be either 'whitelist' or 'blacklist', got ''", + }, + { + name: "Empty RecordLabel field", + request: RequestData{Indexer: "ops", RecordLabel: ""}, + wantErr: false, + errMsg: "", }, } diff --git a/internal/api/cache.go b/internal/api/cache.go index 05b934c..98a2d3c 100644 --- a/internal/api/cache.go +++ b/internal/api/cache.go @@ -1,27 +1,70 @@ package api import ( + "sync" "time" "github.com/rs/zerolog/log" ) -var cache = make(map[string]CacheItem) // keyed by indexer +const ( + cacheExpiryDuration = 5 * time.Minute + cacheCleanupInterval = 10 * time.Minute +) + +type CacheItem struct { + Data *ResponseData + LastFetched time.Time +} + +var ( + cache = make(map[string]CacheItem) + cacheLock sync.RWMutex +) + +func init() { + // Start a background goroutine to periodically clean up expired cache entries. + go startCacheCleanup() +} -// stores the responseData in cache with the specified cacheKey and updates the LastFetched timestamp. func cacheResponseData(cacheKey string, responseData *ResponseData) { + cacheLock.Lock() + defer cacheLock.Unlock() cache[cacheKey] = CacheItem{ Data: responseData, LastFetched: time.Now(), } } -// checks if there is cached data for a given cache key and indexer, -// and returns the cached data if it exists and is not expired. -func checkCache(cacheKey string, indexer string) (*ResponseData, bool) { - if cached, ok := cache[cacheKey]; ok && time.Since(cached.LastFetched) < 5*time.Minute { - log.Trace().Msgf("[%s] Using cached data for %s", indexer, cacheKey) - return cached.Data, true +func checkCache(cacheKey, indexer string) (*ResponseData, bool) { + cacheLock.RLock() + defer cacheLock.RUnlock() + + if cached, ok := cache[cacheKey]; ok { + if time.Since(cached.LastFetched) < cacheExpiryDuration { + log.Trace().Msgf("[%s] Using cached data for %s", indexer, cacheKey) + return cached.Data, true + } } return nil, false } + +func startCacheCleanup() { + for { + time.Sleep(cacheCleanupInterval) + removeExpiredCacheEntries() + } +} + +func removeExpiredCacheEntries() { + cacheLock.Lock() + defer cacheLock.Unlock() + + now := time.Now() + for key, item := range cache { + if now.Sub(item.LastFetched) >= cacheExpiryDuration { + delete(cache, key) + //log.Trace().Msgf("Removed expired cache entry for %s", key) + } + } +} diff --git a/internal/api/config.go b/internal/api/config.go index 2ff5768..aec1476 100644 --- a/internal/api/config.go +++ b/internal/api/config.go @@ -1,43 +1,49 @@ package api import ( + "github.com/inhies/go-bytesize" "github.com/s0up4200/redactedhook/internal/config" ) -// checks if certain fields in the requestData struct are empty or zero, -// and if so, it populates them with values from the cfg struct. +// fallbackToConfig prioritizes webhook data over config data. +// If webhook data is present, it overwrites the existing config data. func fallbackToConfig(requestData *RequestData) { - config := config.GetConfig() + cfg := config.GetConfig() - // Directly assign values from the global config if they are not set in requestData - if requestData.REDUserID == 0 { - requestData.REDUserID = config.UserIDs.REDUserID + // Helper functions to set fields, prioritizing webhook data if present + setInt := func(webhookField *int, configValue int) { + if *webhookField == 0 { + *webhookField = configValue + } } - if requestData.OPSUserID == 0 { - requestData.OPSUserID = config.UserIDs.OPSUserID - } - if requestData.REDKey == "" { - requestData.REDKey = config.IndexerKeys.REDKey - } - if requestData.OPSKey == "" { - requestData.OPSKey = config.IndexerKeys.OPSKey - } - if requestData.MinRatio == 0 { - requestData.MinRatio = config.Ratio.MinRatio - } - if requestData.MinSize == 0 { - requestData.MinSize = config.ParsedSizes.MinSize - } - if requestData.MaxSize == 0 { - requestData.MaxSize = config.ParsedSizes.MaxSize - } - if requestData.Uploaders == "" { - requestData.Uploaders = config.Uploaders.Uploaders + + setFloat64 := func(webhookField *float64, configValue float64) { + if *webhookField == 0 { + *webhookField = configValue + } } - if requestData.Mode == "" { - requestData.Mode = config.Uploaders.Mode + + setByteSize := func(webhookField *bytesize.ByteSize, configValue bytesize.ByteSize) { + if *webhookField == 0 { + *webhookField = configValue + } } - if requestData.RecordLabel == "" { - requestData.RecordLabel = config.RecordLabels.RecordLabels + + setString := func(webhookField *string, configValue string) { + if *webhookField == "" { + *webhookField = configValue + } } + + // Check and set the fields, ensuring webhook data takes priority if present + setInt(&requestData.REDUserID, cfg.UserIDs.REDUserID) + setInt(&requestData.OPSUserID, cfg.UserIDs.OPSUserID) + setString(&requestData.REDKey, cfg.IndexerKeys.REDKey) + setString(&requestData.OPSKey, cfg.IndexerKeys.OPSKey) + setFloat64(&requestData.MinRatio, cfg.Ratio.MinRatio) + setByteSize(&requestData.MinSize, cfg.ParsedSizes.MinSize) + setByteSize(&requestData.MaxSize, cfg.ParsedSizes.MaxSize) + setString(&requestData.Uploaders, cfg.Uploaders.Uploaders) + setString(&requestData.Mode, cfg.Uploaders.Mode) + setString(&requestData.RecordLabel, cfg.RecordLabels.RecordLabels) } diff --git a/internal/api/constants.go b/internal/api/constants.go deleted file mode 100644 index 8489a47..0000000 --- a/internal/api/constants.go +++ /dev/null @@ -1,15 +0,0 @@ -package api - -import "net/http" - -const ( - APIEndpointBaseRedacted = "https://redacted.ch/ajax.php" - APIEndpointBaseOrpheus = "https://orpheus.network/ajax.php" -) - -const ( - StatusUploaderNotAllowed = http.StatusIMUsed + 1 - StatusLabelNotAllowed = http.StatusIMUsed + 2 - StatusSizeNotAllowed = http.StatusIMUsed + 3 - StatusRatioNotAllowed = http.StatusIMUsed -) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index c16e582..27a8b13 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1,44 +1,57 @@ package api import ( - "fmt" + "errors" "net/http" - "strings" "github.com/rs/zerolog/log" "github.com/s0up4200/redactedhook/internal/config" ) -// handles webhooks: auth, decode payload, validate, respond 200. -func WebhookHandler(w http.ResponseWriter, r *http.Request) { - var requestData RequestData +const ( + StatusUploaderNotAllowed = http.StatusIMUsed + 1 + StatusLabelNotAllowed = http.StatusIMUsed + 2 + StatusSizeNotAllowed = http.StatusIMUsed + 3 + StatusRatioNotAllowed = http.StatusIMUsed +) +const ( + ErrInvalidJSONResponse = "invalid JSON response" + ErrRecordLabelNotFound = "record label not found" + ErrRecordLabelNotAllowed = "record label not allowed" + ErrUploaderNotAllowed = "uploader is not allowed" + ErrSizeNotAllowed = "torrent size is outside the requested size range" + ErrRatioBelowMinimum = "returned ratio is below minimum requirement" +) + +func WebhookHandler(w http.ResponseWriter, r *http.Request) { cfg := config.GetConfig() + var requestData RequestData fallbackToConfig(&requestData) if err := verifyAPIKey(r.Header.Get("X-API-Token"), cfg.Authorization.APIToken); err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + writeHTTPError(w, err, http.StatusUnauthorized) return } if err := validateRequestMethod(r.Method); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeHTTPError(w, err, http.StatusBadRequest) return } if err := decodeJSONPayload(r, &requestData); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeHTTPError(w, err, http.StatusBadRequest) return } defer r.Body.Close() if err := validateIndexer(requestData.Indexer); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeHTTPError(w, err, http.StatusBadRequest) return } if err := validateRequestData(&requestData); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeHTTPError(w, err, http.StatusBadRequest) return } @@ -46,68 +59,80 @@ func WebhookHandler(w http.ResponseWriter, r *http.Request) { apiBase, err := determineAPIBase(requestData.Indexer) if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) + writeHTTPError(w, err, http.StatusNotFound) return } reqHeader := make(http.Header) setAuthorizationHeader(&reqHeader, &requestData) - // Call hooks + if hookError := runHooks(&requestData, apiBase); hookError != nil { + handleErrors(w, hookError) + return + } - if requestData.TorrentID != 0 && (requestData.MinSize != 0 || requestData.MaxSize != 0) { - if err := hookSize(&requestData, apiBase); err != nil { - handleErrors(w, err, StatusSizeNotAllowed) - return + w.WriteHeader(http.StatusOK) + log.Info().Msgf("[%s] Conditions met, responding with status 200", requestData.Indexer) +} +func runHooks(requestData *RequestData, apiBase string) error { + if requestData.TorrentID != 0 && (requestData.MinSize != 0 || requestData.MaxSize != 0) { + if err := hookSize(requestData, apiBase); err != nil { + return errors.New(ErrSizeNotAllowed) } } if requestData.TorrentID != 0 && requestData.Uploaders != "" { - if err := hookUploader(&requestData, apiBase); err != nil { - handleErrors(w, err, StatusUploaderNotAllowed) - return - + if err := hookUploader(requestData, apiBase); err != nil { + return errors.New(ErrUploaderNotAllowed) } } if requestData.TorrentID != 0 && requestData.RecordLabel != "" { - if err := hookRecordLabel(&requestData, apiBase); err != nil { - handleErrors(w, err, StatusLabelNotAllowed) - return + if err := hookRecordLabel(requestData, apiBase); err != nil { + return errors.New(ErrRecordLabelNotAllowed) } } if requestData.MinRatio != 0 { - if err := hookRatio(&requestData, apiBase); err != nil { - handleErrors(w, err, StatusRatioNotAllowed) - return - + if err := hookRatio(requestData, apiBase); err != nil { + return errors.New(ErrRatioBelowMinimum) } } - w.WriteHeader(http.StatusOK) // HTTP status code 200 - log.Info().Msgf("[%s] Conditions met, responding with status 200", requestData.Indexer) + return nil } -func handleErrors(w http.ResponseWriter, err error, defaultStatusCode int) { - if strings.Contains(err.Error(), "invalid JSON response") { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return // We're done here, no need to continue. +func writeHTTPError(w http.ResponseWriter, err error, statusCode int) { + http.Error(w, err.Error(), statusCode) +} + +func handleErrors(w http.ResponseWriter, err error) { + if err == nil { + return } - if strings.HasPrefix(err.Error(), "HTTP error:") { - // Extract the status code from the error message - var statusCode int - _, scanErr := fmt.Sscanf(err.Error(), "HTTP error: %d", &statusCode) - if scanErr == nil && statusCode != 0 { - http.Error(w, err.Error(), statusCode) - return // We're done here, too. - } - // Fallback to internal server error if status code extraction fails + switch err.Error() { + case ErrInvalidJSONResponse: + http.Error(w, ErrInvalidJSONResponse, http.StatusInternalServerError) + + case ErrRecordLabelNotFound: + http.Error(w, ErrRecordLabelNotFound, http.StatusBadRequest) + + case ErrRecordLabelNotAllowed: + http.Error(w, ErrRecordLabelNotAllowed, http.StatusForbidden) + + case ErrUploaderNotAllowed: + http.Error(w, ErrUploaderNotAllowed, http.StatusForbidden) + + case ErrSizeNotAllowed: + http.Error(w, ErrSizeNotAllowed, http.StatusBadRequest) + + case ErrRatioBelowMinimum: + http.Error(w, ErrRatioBelowMinimum, http.StatusForbidden) + + default: + log.Error().Err(err).Msg("Unhandled error") http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return // Still done. } - - http.Error(w, err.Error(), defaultStatusCode) } diff --git a/internal/api/hooks.go b/internal/api/hooks.go index d2744ba..c3ff0aa 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -9,31 +9,60 @@ import ( "github.com/rs/zerolog/log" ) -// checks if the uploader is allowed based on the requestData. -func hookUploader(requestData *RequestData, apiBase string) error { +type RequestData struct { + REDUserID int `json:"red_user_id,omitempty"` + OPSUserID int `json:"ops_user_id,omitempty"` + TorrentID int `json:"torrent_id,omitempty"` + REDKey string `json:"red_apikey,omitempty"` + OPSKey string `json:"ops_apikey,omitempty"` + MinRatio float64 `json:"minratio,omitempty"` + MinSize bytesize.ByteSize `json:"minsize,omitempty"` + MaxSize bytesize.ByteSize `json:"maxsize,omitempty"` + Uploaders string `json:"uploaders,omitempty"` + RecordLabel string `json:"record_labels,omitempty"` + Mode string `json:"mode,omitempty"` + Indexer string `json:"indexer"` +} + +type ResponseData struct { + Status string `json:"status"` + Error string `json:"error"` + Response struct { + Username string `json:"username"` + Stats struct { + Ratio float64 `json:"ratio"` + } `json:"stats"` + Group struct { + Name string `json:"name"` + MusicInfo struct { + Artists []struct { + ID int `json:"id"` + Name string `json:"name"` + } `json:"artists"` + } `json:"musicInfo"` + } `json:"group"` + Torrent *struct { + Username string `json:"username"` + Size int64 `json:"size"` + RecordLabel string `json:"remasterRecordLabel"` + ReleaseName string `json:"filePath"` + CatalogueNumber string `json:"remasterCatalogueNumber"` + } `json:"torrent"` + } `json:"response"` +} +func hookUploader(requestData *RequestData, apiBase string) error { torrentData, err := fetchResponseData(requestData, requestData.TorrentID, "torrent", apiBase) if err != nil { return err } username := torrentData.Response.Torrent.Username - usernames := strings.Split(requestData.Uploaders, ",") - for i, uname := range usernames { - usernames[i] = strings.TrimSpace(uname) - } + usernames := parseAndTrimList(requestData.Uploaders) - usernamesStr := strings.Join(usernames, ", ") - log.Trace().Msgf("[%s] Requested uploaders [%s]: %s", requestData.Indexer, requestData.Mode, usernamesStr) - - isListed := false - for _, uname := range usernames { - if uname == username { - isListed = true - break - } - } + log.Trace().Msgf("[%s] Requested uploaders [%s]: %s", requestData.Indexer, requestData.Mode, strings.Join(usernames, ", ")) + isListed := stringInSlice(username, usernames) if (requestData.Mode == "blacklist" && isListed) || (requestData.Mode == "whitelist" && !isListed) { log.Debug().Msgf("[%s] Uploader (%s) is not allowed", requestData.Indexer, username) return fmt.Errorf("uploader is not allowed") @@ -41,8 +70,10 @@ func hookUploader(requestData *RequestData, apiBase string) error { return nil } -// checks if the record label is allowed based on the requestData. func hookRecordLabel(requestData *RequestData, apiBase string) error { + requestedRecordLabels := parseAndTrimList(requestData.RecordLabel) + log.Trace().Msgf("[%s] Requested record labels: [%s]", requestData.Indexer, strings.Join(requestedRecordLabels, ", ")) + torrentData, err := fetchResponseData(requestData, requestData.TorrentID, "torrent", apiBase) if err != nil { return err @@ -51,25 +82,19 @@ func hookRecordLabel(requestData *RequestData, apiBase string) error { recordLabel := strings.ToLower(strings.TrimSpace(html.UnescapeString(torrentData.Response.Torrent.RecordLabel))) name := torrentData.Response.Group.Name - requestedRecordLabels := normalizeLabels(strings.Split(requestData.RecordLabel, ",")) if recordLabel == "" { log.Debug().Msgf("[%s] No record label found for release: %s", requestData.Indexer, name) - return fmt.Errorf("record label not allowed") + return fmt.Errorf("record label not found") } - recordLabelsStr := strings.Join(requestedRecordLabels, ", ") - log.Trace().Msgf("[%s] Requested record labels: [%s]", requestData.Indexer, recordLabelsStr) - - isRecordLabelPresent := contains(requestedRecordLabels, recordLabel) - if !isRecordLabelPresent { - log.Debug().Msgf("[%s] The record label '%s' is not included in the requested record labels: [%s]", requestData.Indexer, recordLabel, recordLabelsStr) + if !stringInSlice(recordLabel, requestedRecordLabels) { + log.Debug().Msgf("[%s] The record label '%s' is not included in the requested record labels: [%s]", requestData.Indexer, recordLabel, strings.Join(requestedRecordLabels, ", ")) return fmt.Errorf("record label not allowed") } return nil } -// checks if the torrent size is within the allowed range based on the requestData. func hookSize(requestData *RequestData, apiBase string) error { torrentData, err := fetchResponseData(requestData, requestData.TorrentID, "torrent", apiBase) if err != nil { @@ -77,35 +102,27 @@ func hookSize(requestData *RequestData, apiBase string) error { } torrentSize := bytesize.ByteSize(torrentData.Response.Torrent.Size) - minSize := bytesize.ByteSize(requestData.MinSize) - maxSize := bytesize.ByteSize(requestData.MaxSize) log.Trace().Msgf("[%s] Torrent size: %s, Requested size range: %s - %s", requestData.Indexer, torrentSize, requestData.MinSize, requestData.MaxSize) - if (requestData.MinSize != 0 && torrentSize < minSize) || - (requestData.MaxSize != 0 && torrentSize > maxSize) { - log.Debug().Msgf("[%s] Torrent size %s is outside the requested size range: %s to %s", requestData.Indexer, torrentSize, minSize, maxSize) + if (requestData.MinSize != 0 && torrentSize < requestData.MinSize) || + (requestData.MaxSize != 0 && torrentSize > requestData.MaxSize) { + log.Debug().Msgf("[%s] Torrent size %s is outside the requested size range: %s to %s", requestData.Indexer, torrentSize, requestData.MinSize, requestData.MaxSize) return fmt.Errorf("torrent size is outside the requested size range") } return nil - } -// checks if the user ratio is above the minimum requirement based on the requestData. func hookRatio(requestData *RequestData, apiBase string) error { - userID := requestData.REDUserID + userID := getUserID(requestData) minRatio := requestData.MinRatio - if requestData.Indexer == "ops" { - userID = requestData.OPSUserID - } - // Check for incomplete configuration if userID == 0 || minRatio == 0 { if userID != 0 || minRatio != 0 { log.Warn().Msgf("[%s] Incomplete ratio check configuration: userID or minRatio is missing.", requestData.Indexer) } - return nil // Exit early if either is zero, as the check cannot proceed + return nil } userData, err := fetchResponseData(requestData, userID, "user", apiBase) @@ -125,3 +142,27 @@ func hookRatio(requestData *RequestData, apiBase string) error { return nil } + +func parseAndTrimList(list string) []string { + items := strings.Split(list, ",") + for i, item := range items { + items[i] = strings.TrimSpace(item) + } + return items +} + +func stringInSlice(str string, list []string) bool { + for _, item := range list { + if item == str { + return true + } + } + return false +} + +func getUserID(requestData *RequestData) int { + if requestData.Indexer == "ops" { + return requestData.OPSUserID + } + return requestData.REDUserID +} diff --git a/internal/api/limiters.go b/internal/api/limiters.go index 1a8eb0d..a1ed1ab 100644 --- a/internal/api/limiters.go +++ b/internal/api/limiters.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "time" "github.com/rs/zerolog/log" @@ -17,15 +18,15 @@ func init() { orpheusLimiter = rate.NewLimiter(rate.Every(10*time.Second), 5) } -// returns a rate limiter based on the provided indexer string. -func getLimiter(indexer string) *rate.Limiter { +func getLimiter(indexer string) (*rate.Limiter, error) { switch indexer { case "redacted": - return redactedLimiter + return redactedLimiter, nil case "ops": - return orpheusLimiter + return orpheusLimiter, nil default: - log.Error().Msgf("Invalid indexer: %s", indexer) - return nil + err := fmt.Errorf("invalid indexer: %s", indexer) + log.Error().Err(err).Msg("Failed to get rate limiter") + return nil, err } } diff --git a/internal/api/requests.go b/internal/api/requests.go index cf7397a..9560388 100644 --- a/internal/api/requests.go +++ b/internal/api/requests.go @@ -8,32 +8,33 @@ import ( "html" "io" "net/http" - "strings" "time" "github.com/rs/zerolog/log" "golang.org/x/time/rate" ) -// sends an HTTP GET request to an endpoint with an API key, applies a rate limiter, and unmarshals the response JSON into a target object. +const ( + APIEndpointBaseRedacted = "https://redacted.ch/ajax.php" + APIEndpointBaseOrpheus = "https://orpheus.network/ajax.php" +) + func makeRequest(endpoint, apiKey string, limiter *rate.Limiter, indexer string, target interface{}) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - if !limiter.Allow() { - log.Warn().Msgf("%s: Too many requests", indexer) - return fmt.Errorf("rate limit exceeded for %s", indexer) + if err := limiter.Wait(ctx); err != nil { + log.Warn().Msgf("%s: Rate limit exceeded", indexer) + return fmt.Errorf("rate limit exceeded for %s: %w", indexer, err) } - req, err := http.NewRequest("GET", endpoint, nil) + req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) if err != nil { log.Error().Err(err).Msg("Error creating HTTP request") return err } req.Header.Set("Authorization", apiKey) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - req = req.WithContext(ctx) - resp, err := http.DefaultClient.Do(req) if err != nil { log.Error().Err(err).Msg("Error executing HTTP request") @@ -41,14 +42,6 @@ func makeRequest(endpoint, apiKey string, limiter *rate.Limiter, indexer string, } defer resp.Body.Close() - //dump, err := httputil.DumpResponse(resp, true) - //if err != nil { - // log.Error().Err(err).Msg("Error dumping the response") - // return err - //} - // - //fmt.Printf("HTTP Response:\n%s\n", dump) - if resp.StatusCode >= 400 { errMsg := fmt.Sprintf("HTTP error: %d from %s", resp.StatusCode, endpoint) log.Error().Msg(errMsg) @@ -57,7 +50,7 @@ func makeRequest(endpoint, apiKey string, limiter *rate.Limiter, indexer string, respBody, err := io.ReadAll(resp.Body) if err != nil { - log.Error().Err(err).Msg("fetchAPI error") + log.Error().Err(err).Msg("Error reading response body") return err } @@ -73,24 +66,21 @@ func makeRequest(endpoint, apiKey string, limiter *rate.Limiter, indexer string, } if responseData.Status != "success" { - // log.Warn().Msgf("API error from %s: %s", indexer, responseData.Error) return fmt.Errorf("API error from %s: %s", indexer, responseData.Error) } return nil } -// initiates an API request with the given parameters and returns the response data or an error. -func initiateAPIRequest(id int, action string, apiKey, apiBase, indexer string) (*ResponseData, error) { - limiter := getLimiter(indexer) - if limiter == nil { - return nil, fmt.Errorf("could not get rate limiter for indexer: %s", indexer) +func initiateAPIRequest(id int, action, apiKey, apiBase, indexer string) (*ResponseData, error) { + limiter, err := getLimiter(indexer) + if err != nil { + return nil, fmt.Errorf("could not get rate limiter for indexer: %s, %w", indexer, err) } endpoint := fmt.Sprintf("%s?action=%s&id=%d", apiBase, action, id) responseData := &ResponseData{} - err := makeRequest(endpoint, apiKey, limiter, indexer, responseData) - if err != nil { + if err := makeRequest(endpoint, apiKey, limiter, indexer, responseData); err != nil { return nil, err } @@ -103,13 +93,10 @@ func initiateAPIRequest(id int, action string, apiKey, apiBase, indexer string) return responseData, nil } -// fetches response data from an API, checks the cache first, and caches the response data for future use. -func fetchResponseData(requestData *RequestData, id int, action string, apiBase string) (*ResponseData, error) { - - // Check cache first +// fetchResponseData fetches response data from an API, checks the cache first, and caches the response data for future use. +func fetchResponseData(requestData *RequestData, id int, action, apiBase string) (*ResponseData, error) { cacheKey := fmt.Sprintf("%sID %d", action, id) - cachedData, found := checkCache(cacheKey, requestData.Indexer) - if found { + if cachedData, found := checkCache(cacheKey, requestData.Indexer); found { return cachedData, nil } @@ -120,21 +107,15 @@ func fetchResponseData(requestData *RequestData, id int, action string, apiBase responseData, err := initiateAPIRequest(id, action, apiKey, apiBase, requestData.Indexer) if err != nil { - if strings.Contains(err.Error(), "rate limit exceeded") { - return nil, err - } wrappedErr := fmt.Errorf("error fetching %s data for ID %d: %w", action, id, err) log.Error().Err(wrappedErr).Msg("Data fetching") return nil, wrappedErr } - // Cache the response data cacheResponseData(cacheKey, responseData) - return responseData, nil } -// determines the API base endpoint based on the provided indexer. func determineAPIBase(indexer string) (string, error) { switch indexer { case "redacted": @@ -142,6 +123,17 @@ func determineAPIBase(indexer string) (string, error) { case "ops": return APIEndpointBaseOrpheus, nil default: - return "", fmt.Errorf("invalid path") + return "", fmt.Errorf("invalid indexer: %s", indexer) + } +} + +func getAPIKey(requestData *RequestData) (string, error) { + switch requestData.Indexer { + case "redacted": + return requestData.REDKey, nil + case "ops": + return requestData.OPSKey, nil + default: + return "", errors.New("invalid indexer") } } diff --git a/internal/api/types.go b/internal/api/types.go deleted file mode 100644 index a7294fc..0000000 --- a/internal/api/types.go +++ /dev/null @@ -1,54 +0,0 @@ -package api - -import ( - "time" - - "github.com/inhies/go-bytesize" -) - -type CacheItem struct { - Data *ResponseData - LastFetched time.Time -} - -type RequestData struct { - REDUserID int `json:"red_user_id,omitempty"` - OPSUserID int `json:"ops_user_id,omitempty"` - TorrentID int `json:"torrent_id,omitempty"` - REDKey string `json:"red_apikey,omitempty"` - OPSKey string `json:"ops_apikey,omitempty"` - MinRatio float64 `json:"minratio,omitempty"` - MinSize bytesize.ByteSize `json:"minsize,omitempty"` - MaxSize bytesize.ByteSize `json:"maxsize,omitempty"` - Uploaders string `json:"uploaders,omitempty"` - RecordLabel string `json:"record_labels,omitempty"` - Mode string `json:"mode,omitempty"` - Indexer string `json:"indexer"` -} - -type ResponseData struct { - Status string `json:"status"` - Error string `json:"error"` - Response struct { - Username string `json:"username"` - Stats struct { - Ratio float64 `json:"ratio"` - } `json:"stats"` - Group struct { - Name string `json:"name"` - MusicInfo struct { - Artists []struct { - ID int `json:"id"` - Name string `json:"name"` - } `json:"artists"` - } `json:"musicInfo"` - } `json:"group"` - Torrent *struct { - Username string `json:"username"` - Size int64 `json:"size"` - RecordLabel string `json:"remasterRecordLabel"` - ReleaseName string `json:"filePath"` - CatalogueNumber string `json:"remasterCatalogueNumber"` - } `json:"torrent"` - } `json:"response"` -} diff --git a/internal/api/utils.go b/internal/api/utils.go index faaa105..26365e1 100644 --- a/internal/api/utils.go +++ b/internal/api/utils.go @@ -4,54 +4,30 @@ import ( "encoding/json" "fmt" "net/http" - "strings" -) -// takes a slice of strings and returns a new slice with all the labels -// converted to lowercase and trimmed of any leading or trailing whitespace. -func normalizeLabels(labels []string) []string { - normalized := make([]string, len(labels)) - for i, label := range labels { - normalized[i] = strings.ToLower(strings.TrimSpace(label)) - } - return normalized -} -func contains(slice []string, val string) bool { - for _, item := range slice { - if item == val { - return true - } - } - return false -} + "github.com/rs/zerolog/log" +) -// returns the appropriate API key based on the indexer specified in the `requestData` parameter. -func getAPIKey(requestData *RequestData) (string, error) { +func setAuthorizationHeader(reqHeader *http.Header, requestData *RequestData) error { + var apiKey string switch requestData.Indexer { case "redacted": - return requestData.REDKey, nil - case "ops": - return requestData.OPSKey, nil - default: - return "", fmt.Errorf("invalid indexer: %s", requestData.Indexer) - } -} - -// sets the Authorization header in an HTTP request header based on the indexer specified -func setAuthorizationHeader(reqHeader *http.Header, requestData *RequestData) { - var apiKey string - if requestData.Indexer == "redacted" { apiKey = requestData.REDKey - } else if requestData.Indexer == "ops" { + case "ops": apiKey = requestData.OPSKey + default: + err := fmt.Errorf("invalid indexer: %s", requestData.Indexer) + log.Error().Err(err).Msg("Failed to set authorization header") + return err } reqHeader.Set("Authorization", apiKey) + return nil } -// decodes a JSON payload from an HTTP request and stores it in a struct. func decodeJSONPayload(r *http.Request, requestData *RequestData) error { + defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(requestData); err != nil { - return fmt.Errorf("invalid JSON payload") + return fmt.Errorf("invalid JSON payload: %w", err) } return nil } diff --git a/internal/api/validation.go b/internal/api/validation.go index 26d3eb8..bff8c5b 100644 --- a/internal/api/validation.go +++ b/internal/api/validation.go @@ -9,15 +9,13 @@ import ( "github.com/rs/zerolog/log" ) -// verifyAPIKey checks if the provided API key matches the expected one. -func verifyAPIKey(headerAPIKey string, expectedAPIKey string) error { +func verifyAPIKey(headerAPIKey, expectedAPIKey string) error { if expectedAPIKey == "" || headerAPIKey != expectedAPIKey { return fmt.Errorf("invalid or missing API key") } return nil } -// validateRequestMethod ensures the request uses the POST method. func validateRequestMethod(method string) error { if method != http.MethodPost { return fmt.Errorf("only POST method is supported") @@ -25,29 +23,27 @@ func validateRequestMethod(method string) error { return nil } -// checks if the given `RequestData` object contains valid data and returns an error if any of the validations fail. func validateRequestData(requestData *RequestData) error { safeCharacterRegex := regexp.MustCompile(`^[\p{L}\p{N}\s&,-]+$`) - if requestData.Indexer != "ops" && requestData.Indexer != "redacted" { - errMsg := fmt.Sprintf("invalid indexer: %s", requestData.Indexer) - log.Debug().Msg(errMsg) - return fmt.Errorf(errMsg) + if err := validateIndexer(requestData.Indexer); err != nil { + log.Debug().Err(err).Msg("Validation error") + return err } - if requestData.TorrentID > 999999999 { + if requestData.TorrentID > 999_999_999 { errMsg := fmt.Sprintf("invalid torrent ID: %d", requestData.TorrentID) log.Debug().Msg(errMsg) return fmt.Errorf(errMsg) } - if requestData.REDKey != "" && len(requestData.REDKey) > 42 { + if len(requestData.REDKey) > 42 { errMsg := "REDKey is too long" log.Debug().Msg(errMsg) return fmt.Errorf(errMsg) } - if requestData.OPSKey != "" && len(requestData.OPSKey) > 120 { + if len(requestData.OPSKey) > 120 { errMsg := "OPSKey is too long" log.Debug().Msg(errMsg) return fmt.Errorf(errMsg) @@ -60,7 +56,7 @@ func validateRequestData(requestData *RequestData) error { } if requestData.MaxSize > 0 && requestData.MinSize > requestData.MaxSize { - errMsg := "minsize cannot be greater than maxsize" + errMsg := "minSize cannot be greater than maxSize" log.Debug().Msg(errMsg) return fmt.Errorf(errMsg) } @@ -88,7 +84,6 @@ func validateRequestData(requestData *RequestData) error { return nil } -// checks if a given indexer string is valid or not. func validateIndexer(indexer string) error { if indexer != "ops" && indexer != "redacted" { if indexer == "" { diff --git a/internal/config/config.go b/internal/config/config.go index 747b2bd..db86ca4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,10 +19,14 @@ func GetConfig() *Config { } func CreateConfigFile() string { - config := `[authorization] + config := `[server] +host = "127.0.0.1" # Server host +port = 42135 # Server port + +[authorization] api_token = "" # generate with "redactedhook generate-apitoken" # the api_token needs to be set as a header for the webhook to work -# eg. X-API-Token=asd987gsd98g7324kjh142kjh +# eg. Header: X-API-Token=aaa129cd1d66ed6fa567da2d07a5dd0e [indexer_keys] #red_apikey = "" # generate in user settings, needs torrent and user privileges @@ -56,10 +60,10 @@ maxage = 28 # Max age in days to keep a log file compress = false # Whether to compress old log files ` - err := os.WriteFile("config.toml", []byte(config), 0644) + err := os.WriteFile(defaultConfigFileName, []byte(config), 0644) if err != nil { log.Fatal().Err(err).Msg("Failed to write default configuration file") } fmt.Println("Configuration file 'config.toml' generated.") - return "config.toml" + return defaultConfigFileName } diff --git a/internal/config/config_loader.go b/internal/config/config_loader.go index 64919ae..7ead760 100644 --- a/internal/config/config_loader.go +++ b/internal/config/config_loader.go @@ -2,6 +2,8 @@ package config import ( "errors" + "fmt" + "os" "strings" "github.com/fsnotify/fsnotify" @@ -18,7 +20,6 @@ func InitConfig(configPath string) { } func setupViper(configFile string) { - // Set default values before reading the config file viper.SetDefault("userid.red_user_id", 0) viper.SetDefault("userid.ops_user_id", 0) viper.SetDefault("ratio.minratio", 0) @@ -28,17 +29,11 @@ func setupViper(configFile string) { viper.SetDefault("uploaders.mode", "") viper.SetDefault("record_labels.record_labels", "") - viper.SetConfigType(defaultConfigType) + viper.SetConfigType("toml") viper.AutomaticEnv() viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.SetConfigFile(configFile) - // Uncomment this if you want to ensure the config file exists - // and create it if it does not. - // if err := createConfigFileIfNotExist(configFile); err != nil { - // log.Fatal().Err(err).Msg("Failed to create or verify config file") - // } - if err := viper.ReadInConfig(); err != nil { log.Fatal().Err(err).Msg("Error reading config file") } @@ -55,26 +50,22 @@ func readAndUnmarshalConfig() { } func parseSizeCheck() { - // Parse MinSize minSizeStr := viper.GetString("sizecheck.minsize") if minSizeStr == "" { - config.ParsedSizes.MinSize = 0 // Reset to default when empty string is provided + config.ParsedSizes.MinSize = 0 } else { - minSize, err := bytesize.Parse(minSizeStr) - if err != nil { + if minSize, err := bytesize.Parse(minSizeStr); err != nil { log.Error().Err(err).Msg("Invalid format for MinSize; unable to parse") } else { config.ParsedSizes.MinSize = minSize } } - // Parse MaxSize maxSizeStr := viper.GetString("sizecheck.maxsize") if maxSizeStr == "" { - config.ParsedSizes.MaxSize = 0 // Reset to default when empty string is provided + config.ParsedSizes.MaxSize = 0 } else { - maxSize, err := bytesize.Parse(maxSizeStr) - if err != nil { + if maxSize, err := bytesize.Parse(maxSizeStr); err != nil { log.Error().Err(err).Msg("Invalid format for MaxSize; unable to parse") } else { config.ParsedSizes.MaxSize = maxSize @@ -85,87 +76,87 @@ func parseSizeCheck() { func watchConfigChanges() { viper.WatchConfig() viper.OnConfigChange(func(e fsnotify.Event) { + handleConfigChange(e) + }) +} - oldConfig := config +func handleConfigChange(e fsnotify.Event) { + oldConfig := config - if err := viper.ReadInConfig(); err != nil { - log.Error().Err(err).Msg("Error reading config") - return - } - if err := viper.Unmarshal(&config); err != nil { - log.Error().Err(err).Msg("Error unmarshalling config") - return - } - - parseSizeCheck() + if err := viper.ReadInConfig(); err != nil { + log.Error().Err(err).Msg("Error reading config") + return + } + if err := viper.Unmarshal(&config); err != nil { + log.Error().Err(err).Msg("Error unmarshalling config") + return + } - logConfigChanges(oldConfig, config) + parseSizeCheck() + logConfigChanges(oldConfig, config) - if oldConfig.Logs.LogLevel != config.Logs.LogLevel { - configureLogger() - } - log.Debug().Msgf("Config file updated: %s", e.Name) - }) + if oldConfig.Logs.LogLevel != config.Logs.LogLevel { + configureLogger() + } + log.Debug().Msgf("Config file updated: %s", e.Name) } func logConfigChanges(oldConfig, newConfig Config) { - - if oldConfig.IndexerKeys.REDKey != newConfig.IndexerKeys.REDKey { // IndexerKeys + if oldConfig.Server.Host != newConfig.Server.Host { + log.Debug().Msgf("Server host changed from %s to %s", oldConfig.Server.Host, newConfig.Server.Host) + } + if oldConfig.IndexerKeys.REDKey != newConfig.IndexerKeys.REDKey { log.Debug().Msg("red_apikey changed") } if oldConfig.IndexerKeys.OPSKey != newConfig.IndexerKeys.OPSKey { log.Debug().Msg("ops_apikey changed") } - if oldConfig.UserIDs.REDUserID != newConfig.UserIDs.REDUserID { // UserIDs + if oldConfig.UserIDs.REDUserID != newConfig.UserIDs.REDUserID { log.Debug().Msgf("REDUserID changed from %d to %d", oldConfig.UserIDs.REDUserID, newConfig.UserIDs.REDUserID) } if oldConfig.UserIDs.OPSUserID != newConfig.UserIDs.OPSUserID { log.Debug().Msgf("OPSUserID changed from %d to %d", oldConfig.UserIDs.OPSUserID, newConfig.UserIDs.OPSUserID) } - if oldConfig.Ratio.MinRatio != newConfig.Ratio.MinRatio { // Ratio + if oldConfig.Ratio.MinRatio != newConfig.Ratio.MinRatio { log.Debug().Msgf("MinRatio changed from %f to %f", oldConfig.Ratio.MinRatio, newConfig.Ratio.MinRatio) } - oldMinSize, _ := bytesize.Parse(oldConfig.SizeCheck.MinSize) - newMinSize, _ := bytesize.Parse(newConfig.SizeCheck.MinSize) - if oldMinSize != newMinSize { // SizeCheck - log.Debug().Msgf("MinSize changed from %s to %s", oldConfig.SizeCheck.MinSize, newConfig.SizeCheck.MinSize) + if oldConfig.ParsedSizes.MinSize != newConfig.ParsedSizes.MinSize { + log.Debug().Msgf("MinSize changed from %s to %s", oldConfig.ParsedSizes.MinSize, newConfig.ParsedSizes.MinSize) } - oldMaxSize, _ := bytesize.Parse(oldConfig.SizeCheck.MaxSize) - newMaxSize, _ := bytesize.Parse(newConfig.SizeCheck.MaxSize) - if oldMaxSize != newMaxSize { // SizeCheck - log.Debug().Msgf("MaxSize changed from %s to %s", oldConfig.SizeCheck.MaxSize, newConfig.SizeCheck.MaxSize) + if oldConfig.ParsedSizes.MaxSize != newConfig.ParsedSizes.MaxSize { + log.Debug().Msgf("MaxSize changed from %s to %s", oldConfig.ParsedSizes.MaxSize, newConfig.ParsedSizes.MaxSize) } - if oldConfig.Uploaders.Uploaders != newConfig.Uploaders.Uploaders { // Uploaders + if oldConfig.Uploaders.Uploaders != newConfig.Uploaders.Uploaders { log.Debug().Msgf("Uploaders changed from %s to %s", oldConfig.Uploaders.Uploaders, newConfig.Uploaders.Uploaders) } - if oldConfig.Uploaders.Mode != newConfig.Uploaders.Mode { // Uploaders + if oldConfig.Uploaders.Mode != newConfig.Uploaders.Mode { log.Debug().Msgf("Uploader mode changed from %s to %s", oldConfig.Uploaders.Mode, newConfig.Uploaders.Mode) } - if oldConfig.Logs.LogLevel != newConfig.Logs.LogLevel { // Logs + if oldConfig.Logs.LogLevel != newConfig.Logs.LogLevel { log.Debug().Msgf("Log level changed from %s to %s", oldConfig.Logs.LogLevel, newConfig.Logs.LogLevel) } - if oldConfig.Logs.LogToFile != newConfig.Logs.LogToFile { // Logs + if oldConfig.Logs.LogToFile != newConfig.Logs.LogToFile { log.Debug().Msgf("LogToFile changed from %t to %t", oldConfig.Logs.LogToFile, newConfig.Logs.LogToFile) } - if oldConfig.Logs.LogFilePath != newConfig.Logs.LogFilePath { // Logs + if oldConfig.Logs.LogFilePath != newConfig.Logs.LogFilePath { log.Debug().Msgf("LogFilePath changed from %s to %s", oldConfig.Logs.LogFilePath, newConfig.Logs.LogFilePath) } - if oldConfig.Logs.MaxSize != newConfig.Logs.MaxSize { // Logs + if oldConfig.Logs.MaxSize != newConfig.Logs.MaxSize { log.Debug().Msgf("Logs MaxSize changed from %d to %d", oldConfig.Logs.MaxSize, newConfig.Logs.MaxSize) } - if oldConfig.Logs.MaxBackups != newConfig.Logs.MaxBackups { // Logs + if oldConfig.Logs.MaxBackups != newConfig.Logs.MaxBackups { log.Debug().Msgf("Logs MaxBackups changed from %d to %d", oldConfig.Logs.MaxBackups, newConfig.Logs.MaxBackups) } - if oldConfig.Logs.MaxAge != newConfig.Logs.MaxAge { // Logs + if oldConfig.Logs.MaxAge != newConfig.Logs.MaxAge { log.Debug().Msgf("Logs MaxAge changed from %d to %d", oldConfig.Logs.MaxAge, newConfig.Logs.MaxAge) } - if oldConfig.Logs.Compress != newConfig.Logs.Compress { // Logs + if oldConfig.Logs.Compress != newConfig.Logs.Compress { log.Debug().Msgf("Logs Compress changed from %t to %t", oldConfig.Logs.Compress, newConfig.Logs.Compress) } } @@ -173,10 +164,13 @@ func logConfigChanges(oldConfig, newConfig Config) { func ValidateConfig() error { var validationErrors []string - if !viper.IsSet("authorization.api_token") || viper.GetString("authorization.api_token") == "" { - validationErrors = append(validationErrors, "Authorization API Token is required") + apiToken := viper.GetString("authorization.api_token") + if envToken, exists := os.LookupEnv("REDACTEDHOOK__API_TOKEN"); exists { + apiToken = envToken + } + if apiToken == "" { + validationErrors = append(validationErrors, "Authorization API Token is required.") } - if viper.IsSet("indexer_keys.red_apikey") && viper.GetString("indexer_keys.red_apikey") == "" { validationErrors = append(validationErrors, "Indexer REDKey should not be empty") } @@ -185,38 +179,6 @@ func ValidateConfig() error { validationErrors = append(validationErrors, "Indexer OPSKey should not be empty") } - //if viper.IsSet("userid.red_user_id") && viper.GetInt("userid.red_user_id") <= 0 { - // validationErrors = append(validationErrors, "Invalid RED User ID") - //} - - //if viper.IsSet("userid.ops_user_id") && viper.GetInt("userid.ops_user_id") <= 0 { - // validationErrors = append(validationErrors, "Invalid OPS User ID") - //} - - //if viper.IsSet("ratio.minratio") && viper.GetFloat64("ratio.minratio") <= 0 { - // validationErrors = append(validationErrors, "Minimum ratio should be positive") - //} - - //if viper.IsSet("sizecheck.minsize") && viper.GetString("sizecheck.minsize") == "" { - // validationErrors = append(validationErrors, "Invalid minimum size") - //} - - //if viper.IsSet("sizecheck.maxsize") && viper.GetString("sizecheck.maxsize") == "" { - // validationErrors = append(validationErrors, "Invalid maximum size") - //} - - //if viper.IsSet("uploaders.uploaders") && viper.GetString("uploaders.uploaders") == "" { - // validationErrors = append(validationErrors, "Invalid uploader list") - //} - - //if viper.IsSet("uploaders.mode") && viper.GetString("uploaders.mode") == "" { - // validationErrors = append(validationErrors, "Invalid uploader mode set") - //} - - //if viper.IsSet("record_labels.record_labels") && viper.GetString("record_labels.record_labels") == "" { - // validationErrors = append(validationErrors, "Invalid record_labels set") - //} - if !viper.IsSet("logs.loglevel") || viper.GetString("logs.loglevel") == "" { validationErrors = append(validationErrors, "Log level is required") } @@ -245,6 +207,23 @@ func ValidateConfig() error { validationErrors = append(validationErrors, "Compress flag is required") } + host := viper.GetString("server.host") + if envHost, exists := os.LookupEnv("REDACTEDHOOK__HOST"); exists { + host = envHost + } + if host == "" { + validationErrors = append(validationErrors, "Server host is required either in config or as an environment variable.") + } + + port := viper.GetInt("server.port") + if envPort, exists := os.LookupEnv("REDACTEDHOOK__PORT"); exists { + fmt.Sscanf(envPort, "%d", &port) + } + + if port <= 0 { + validationErrors = append(validationErrors, "Server port is required either in config or as a positive integer environment variable.") + } + if len(validationErrors) > 0 { return errors.New(strings.Join(validationErrors, "; ")) } diff --git a/internal/config/config_structs.go b/internal/config/config_structs.go index 5bf87df..c480c07 100644 --- a/internal/config/config_structs.go +++ b/internal/config/config_structs.go @@ -14,6 +14,12 @@ type Config struct { Uploaders Uploaders `mapstructure:"uploaders"` RecordLabels RecordLabels `mapstructure:"record_labels"` Logs Logs `mapstructure:"logs"` + Server Server `mapstructure:"server"` +} + +type Server struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` } type Authorization struct { diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..45ec9fc --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,83 @@ +package config + +import ( + "bytes" + "os" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func setupTestEnv() { + viper.Reset() + os.Clearenv() + viper.SetConfigType("toml") + viper.SetConfigFile("testconfig.toml") + viper.ReadConfig(bytes.NewReader([]byte(` + [authorization] + api_token = "test_token" + + [indexer_keys] + red_apikey = "red_key" + ops_apikey = "ops_key" + + [userid] + red_user_id = 1 + ops_user_id = 2 + + [ratio] + minratio = 0.5 + + [sizecheck] + minsize = "10MB" + maxsize = "1GB" + + [uploaders] + uploaders = "test_uploader" + mode = "whitelist" + + [record_labels] + record_labels = "test_label" + + [logs] + loglevel = "debug" + logtofile = false + logfilepath = "test.log" + maxsize = 5 + maxbackups = 2 + maxage = 7 + compress = false + + [server] + host = "127.0.0.1" + port = 42135 + `))) +} + +func TestValidateConfig(t *testing.T) { + setupTestEnv() + viper.Set("authorization.api_token", "") + + err := ValidateConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "Authorization API Token is required.") + + viper.Set("authorization.api_token", "valid_token") + err = ValidateConfig() + assert.NoError(t, err) +} + +func TestWatchConfigChanges(t *testing.T) { + setupTestEnv() + + // simulate a config file change + viper.Set("server.port", 8080) + err := viper.WriteConfigAs("testconfig_updated.toml") + assert.NoError(t, err) + + InitConfig("testconfig_updated.toml") + assert.Equal(t, 8080, config.Server.Port) + + os.Remove("testconfig_updated.toml") +} diff --git a/internal/config/logger.go b/internal/config/logger.go index 5acccbb..cd8bb2e 100644 --- a/internal/config/logger.go +++ b/internal/config/logger.go @@ -12,16 +12,13 @@ import ( func configureLogger() { var writers []io.Writer - // Always log to console + // always log to console consoleWriter := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02 15:04:05"} writers = append(writers, consoleWriter) - // If logtofile is true, also log to file if config.Logs.LogToFile { - logFilePath := config.Logs.LogFilePath - if logFilePath == "" && isRunningInDocker() { - logFilePath = "/redactedhook/redactedhook.log" // Use a sensible default in Docker - } + logFilePath := determineLogFilePath() + fileWriter := &lumberjack.Logger{ Filename: logFilePath, MaxSize: config.Logs.MaxSize, // megabytes @@ -36,18 +33,24 @@ func configureLogger() { multiWriter := zerolog.MultiLevelWriter(writers...) log.Logger = zerolog.New(multiWriter).With().Timestamp().Logger() - // Set the log level setLogLevel(config.Logs.LogLevel) } func setLogLevel(level string) { - loglevel, err := zerolog.ParseLevel(level) + logLevel, err := zerolog.ParseLevel(level) if err != nil { - // If the provided log level is invalid, log an error and default to debug level. log.Error().Msgf("Invalid log level '%s', defaulting to 'debug'", level) - loglevel = zerolog.DebugLevel + logLevel = zerolog.DebugLevel } - // Apply the determined log level. - zerolog.SetGlobalLevel(loglevel) + zerolog.SetGlobalLevel(logLevel) +} + +func determineLogFilePath() string { + logFilePath := config.Logs.LogFilePath + if logFilePath == "" && isRunningInDocker() { + // use a sensible default log file path in Docker + logFilePath = "/redactedhook/redactedhook.log" + } + return logFilePath } diff --git a/internal/config/path_utils.go b/internal/config/path_utils.go index b20e308..47aff82 100644 --- a/internal/config/path_utils.go +++ b/internal/config/path_utils.go @@ -8,10 +8,8 @@ import ( ) func isRunningInDocker() bool { - if _, err := os.Stat("/.dockerenv"); err == nil { - return true - } - return false + _, err := os.Stat("/.dockerenv") + return err == nil } func determineConfigFile(configPath string) string { @@ -19,15 +17,13 @@ func determineConfigFile(configPath string) string { return configPath } - configDir := defaultConfigDir + var configDir string if isRunningInDocker() { - // In Docker, default to the mapped volume directory configDir = os.Getenv("XDG_CONFIG_HOME") if configDir == "" { configDir = "/redactedhook" } } else { - // For non-Docker, use the user's home directory with .config/redactedhook/ homeDir, err := os.UserHomeDir() if err != nil { log.Fatal().Err(err).Msg("Failed to get user home directory") @@ -36,11 +32,5 @@ func determineConfigFile(configPath string) string { } configFile := filepath.Join(configDir, defaultConfigFileName) - - //// Ensure the config file exists - //if err := createConfigFileIfNotExist(configFile); err != nil { - // log.Fatal().Err(err).Msg("Failed to create or verify config file") - //} - return configFile }