diff --git a/.gitignore b/.gitignore index 9b87080..f7f42d9 100644 --- a/.gitignore +++ b/.gitignore @@ -24,9 +24,9 @@ _testmain.go *.prof # User specific -/conf/gunfish.toml +/config/gunfish.toml /gunfish -/shotgun +/gunfish-cli +/apnsmock /test/server* -/h2o_access.log /pkg diff --git a/.travis.yml b/.travis.yml index a82d984..c34d6ed 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,24 +7,18 @@ notifications: env: global: - - H2O_VERSION=1.6.3 - PROJECT_ROOT=$HOME/gopath/src/github.com/kayac/Gunfish go: - - 1.9.x - 1.10.3 - tip addons: apt: - sources: - - kalakris-cmake packages: - - cmake - curl install: - - cd $PROJECT_ROOT && ./test/scripts/build_h2o.sh - cd $PROJECT_ROOT && make get-dep-on-ci && make get-deps - go get github.com/mitchellh/gox @@ -44,7 +38,3 @@ deploy: on: tags: true go: 1.10.3 - -cache: - directories: - - $PROJECT_ROOT/h2o-$H2O_VERSION diff --git a/Makefile b/Makefile index 793f76b..2c09842 100644 --- a/Makefile +++ b/Makefile @@ -31,11 +31,9 @@ gen-cert: test/scripts/gen_test_cert.sh test: gen-cert - nohup h2o -c conf/h2o/h2o.conf > h2o_access.log & - go test -v ./apns || ( pkill h2o && exit 1 ) - go test -v ./fcm || ( pkill h2o && exit 1 ) - go test -v . || ( pkill h2o && exit 1 ) - pkill h2o + go test -v ./apns + go test -v ./fcm + go test -v . clean: rm -f cmd/gunfish/gunfish @@ -44,3 +42,6 @@ clean: build: go build -gcflags="-trimpath=${HOME}" -ldflags="-w" cmd/gunfish/gunfish.go + +tools/%: + go build -gcflags="-trimpath=${HOME}" -ldflags="-w" test/tools/$*/$*.go diff --git a/README.md b/README.md index 415a7d8..41d565b 100644 --- a/README.md +++ b/README.md @@ -171,7 +171,6 @@ max_connections = 2000 error_hook = "echo -e 'Hello Gunfish at error hook!'" [apns] -skip_insecure = true key_file = "/path/to/server.key" cert_file = "/path/to/server.crt" @@ -186,7 +185,6 @@ worker_num |optional| Number of Gunfish owns http clients. queue_size |optional| Limit number of posted JSON from the developer application. max_request_size |optional| Limit size of Posted JSON array. max_connections |optional| Max connections -skip_insecure |optional| Controls whether a client verifies the server's certificate chain and host name. key_file |required| The key file path. cert_file |required| The cert file path. error_hook |optional| Error hook command. This command runs when Gunfish catches an error response. @@ -291,20 +289,37 @@ InitErrorResponseHandler(CustomYourErrorHandler{hookCmd: "echo 'on error!'"}) You can implement a success custom handler in the same way but a hook command is not executed in the success handler in order not to make cpu resource too tight. ### Test -To do test for Gunfish, you have to install [h2o](https://h2o.examp1e.net/). **h2o** is used as APNS mock server. So, if you want to test or optimize parameters for your application, you need to prepare the envronment that h2o APNs Mock server works. -Moreover, you have to build h2o with **mruby-sleep** mrbgem. +``` +$ make test +``` +The following tools are useful to send requests to gunfish for test the following. +- gunfish-cli (send push notification to Gunfish for test) +- apnsmock (APNs mock server) ``` -$ make test +$ make tools/gunfish-cli +$ make tools/apnsmock +``` + +- send a request example with gunfish-cli +``` +$ ./gunfish-cli -type apns -count 1 -json-file some.json -verbose +$ ./gunfish-cli -type apns -count 1 -token -apns-topic -options key1=val1,key2=val2 -verbose +``` + +- start apnsmock server +``` +$ ./apnsmock -cert-file ./test/server.crt -key-file ./test/server.key -verbose ``` ### Benchmark Gunfish repository includes Lua script for the benchmark. You can use wrk command with `err_and_success.lua` script. ``` -$ h2o -c conf/h2o/h2o.conf & +$ make tools/apnsmock +$ ./apnsmock -cert-file ./test/server.crt -key-file ./test/server.key -verbosea & $ ./gunfish -c test/gunfish_test.toml -E test $ wrk2 -t2 -c20 -s bench/scripts/err_and_success.lua -L -R100 http://localhost:38103 ``` diff --git a/apns/client.go b/apns/client.go index cd893a3..e270ed7 100644 --- a/apns/client.go +++ b/apns/client.go @@ -10,6 +10,7 @@ import ( "net/url" "time" + "github.com/kayac/Gunfish/config" "golang.org/x/net/http2" ) @@ -18,10 +19,28 @@ const ( HTTP2ClientTimeout = time.Second * 10 ) +var ClientTransport = func(cert tls.Certificate) *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } +} + +type authToken struct { + jwt string + issuedAt time.Time +} + // Client is apns client type Client struct { - Host string - client *http.Client + Host string + client *http.Client + authToken authToken + kid string + teamID string + key []byte + useAuthToken bool } // Send sends notifications to apns @@ -90,51 +109,79 @@ func (ac *Client) NewRequest(token string, h *Header, payload Payload) (*http.Re } } + // APNs provider token authenticaton + if ac.useAuthToken { + // If iat of jwt is more than 1 hour ago, returns 403 InvalidProviderToken. + // So, recreate jwt earlier than 1 hour. + if ac.authToken.issuedAt.Add(time.Hour - time.Minute).Before(time.Now()) { + if err := ac.issueToken(); err != nil { + return nil, err + } + } + nreq.Header.Set("Authorization", "bearer "+ac.authToken.jwt) + } + return nreq, err } -// NewConnection establishes a http2 connection -func NewConnection(certFile, keyFile string, secuskip bool) (*http.Client, error) { - certPEMBlock, err := ioutil.ReadFile(certFile) - if err != nil { - return nil, err - } +func (ac *Client) issueToken() error { + var err error + now := time.Now() - keyPEMBlock, err := ioutil.ReadFile(keyFile) + ac.authToken.jwt, err = CreateJWT(ac.key, ac.kid, ac.teamID, now) if err != nil { - return nil, err + return err } + ac.authToken.issuedAt = now + return nil +} - cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) +func NewClient(conf config.SectionApns) (*Client, error) { + useAuthToken := conf.Kid != "" && conf.TeamID != "" + tr := &http.Transport{} + if !useAuthToken { + certPEMBlock, err := ioutil.ReadFile(conf.CertFile) + if err != nil { + return nil, err + } - if err != nil { - return nil, err - } + keyPEMBlock, err := ioutil.ReadFile(conf.KeyFile) + if err != nil { + return nil, err + } - tr := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: secuskip, - Certificates: []tls.Certificate{cert}, - }, + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return nil, err + } + tr = ClientTransport(cert) } if err := http2.ConfigureTransport(tr); err != nil { return nil, err } - return &http.Client{ - Timeout: HTTP2ClientTimeout, - Transport: tr, - }, nil -} - -func NewClient(host, cert, key string, skipInsecure bool) (*Client, error) { - c, err := NewConnection(cert, key, skipInsecure) + key, err := ioutil.ReadFile(conf.KeyFile) if err != nil { return nil, err } - return &Client{ - Host: host, - client: c, - }, nil + + client := &Client{ + Host: conf.Host, + client: &http.Client{ + Timeout: HTTP2ClientTimeout, + Transport: tr, + }, + kid: conf.Kid, + teamID: conf.TeamID, + key: key, + useAuthToken: useAuthToken, + } + if client.useAuthToken { + if err := client.issueToken(); err != nil { + return nil, err + } + } + + return client, nil } diff --git a/apns/jwt.go b/apns/jwt.go new file mode 100644 index 0000000..3c1bc79 --- /dev/null +++ b/apns/jwt.go @@ -0,0 +1,112 @@ +package apns + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/json" + "encoding/pem" + "io" + "math/big" + "time" +) + +// https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html#//apple_ref/doc/uid/TP40008194-CH11-SW1 + +const jwtDefaultGrowSize = 256 + +type jwtHeader struct { + Alg string `json:"alg"` + Kid string `json:"kid"` +} + +type jwtClaim struct { + Iss string `json:"iss"` + Iat int64 `json:"iat"` +} + +type ecdsaSignature struct { + R, S *big.Int +} + +func CreateJWT(key []byte, kid string, teamID string, now time.Time) (string, error) { + var b bytes.Buffer + b.Grow(jwtDefaultGrowSize) + + header := jwtHeader{ + Alg: "ES256", + Kid: kid, + } + headerJSON, err := json.Marshal(&header) + if err != nil { + return "", err + } + if err := writeAsBase64(&b, headerJSON); err != nil { + return "", err + } + b.WriteByte(byte('.')) + + claim := jwtClaim{ + Iss: teamID, + Iat: now.Unix(), + } + claimJSON, err := json.Marshal(&claim) + if err != nil { + return "", err + } + if err := writeAsBase64(&b, claimJSON); err != nil { + return "", err + } + + sig, err := createSignature(b.Bytes(), key) + if err != nil { + return "", err + } + b.WriteByte(byte('.')) + + if err := writeAsBase64(&b, sig); err != nil { + return "", err + } + + return b.String(), nil +} + +func writeAsBase64(w io.Writer, byt []byte) error { + enc := base64.NewEncoder(base64.RawURLEncoding, w) + defer enc.Close() + + if _, err := enc.Write(byt); err != nil { + return err + } + return nil +} + +func createSignature(payload []byte, key []byte) ([]byte, error) { + h := crypto.SHA256.New() + if _, err := h.Write(payload); err != nil { + return nil, err + } + msg := h.Sum(nil) + + block, _ := pem.Decode(key) + p8key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + + r, s, err := ecdsa.Sign(rand.Reader, p8key.(*ecdsa.PrivateKey), msg) + if err != nil { + return nil, err + } + + sig, err := asn1.Marshal(ecdsaSignature{r, s}) + if err != nil { + return nil, err + } + + return sig, nil +} diff --git a/apns/mock_server.go b/apns/mock_server.go index 21da58e..bb08212 100644 --- a/apns/mock_server.go +++ b/apns/mock_server.go @@ -96,18 +96,6 @@ func StartAPNSMockServer(cert, key string) { log.Fatal(s.Serve(tlsListener)) } -// StopAPNSServer stops APNS Mock server -func StopAPNSServer(cert, key string, insecure bool) error { - client, err := NewConnection(cert, key, insecure) - if err != nil { - return err - } - - client.Get("/stop") - - return nil -} - func createErrorResponse(ermsg ErrorResponseCode, status int) string { var er ErrorResponse if status == http.StatusGone { diff --git a/cmd/gunfish/gunfish.go b/cmd/gunfish/gunfish.go index 0e4400b..fb03c07 100644 --- a/cmd/gunfish/gunfish.go +++ b/cmd/gunfish/gunfish.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "flag" "fmt" "net" @@ -11,6 +12,8 @@ import ( "strconv" "github.com/kayac/Gunfish" + "github.com/kayac/Gunfish/apns" + "github.com/kayac/Gunfish/config" "github.com/sirupsen/logrus" ) @@ -18,7 +21,7 @@ var version string func main() { var ( - config string + confPath string environment string logFormat string port int @@ -27,8 +30,8 @@ func main() { logLevel string ) - flag.StringVar(&config, "config", "/etc/gunfish/config.toml", "specify config file.") - flag.StringVar(&config, "c", "/etc/gunfish/config.toml", "specify config file.") + flag.StringVar(&confPath, "config", "/etc/gunfish/config.toml", "specify config file.") + flag.StringVar(&confPath, "c", "/etc/gunfish/config.toml", "specify config file.") flag.StringVar(&environment, "environment", "production", "APNS environment. (production, development, or test)") flag.StringVar(&environment, "E", "production", "APNS environment. (production, development, or test)") flag.IntVar(&port, "port", 0, "Gunfish port number (range 1024-65535).") @@ -50,7 +53,7 @@ func main() { initLogrus(logFormat, logLevel) - c, err := gunfish.LoadConfig(config) + c, err := config.LoadConfig(confPath) if err != nil { logrus.Error(err) os.Exit(1) @@ -69,6 +72,14 @@ func main() { env = gunfish.Development case "test": env = gunfish.Test + apns.ClientTransport = func(cert tls.Certificate) *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + }, + } + } default: logrus.Error("Unknown environment: %s. Please look at help.", environment) os.Exit(1) diff --git a/conf/h2o/h2o.conf b/conf/h2o/h2o.conf deleted file mode 100644 index 07358c9..0000000 --- a/conf/h2o/h2o.conf +++ /dev/null @@ -1,16 +0,0 @@ -# to find out the configuration commands, run: h2o --help - -num-threads: 4 -http2-max-concurrent-requests-per-connection: 256 -max-connections: 10240 -listen: - port: 2195 - ssl: - certificate-file: test/server.crt - key-file: test/server.key -hosts: - "127.0.0.1.xip.io:2195": - paths: - /: - mruby.handler-file: test/h2o/apnsmock.rb - access-log: /dev/stdout diff --git a/config.go b/config/config.go similarity index 58% rename from config.go rename to config/config.go index fe858f0..00e9181 100644 --- a/config.go +++ b/config/config.go @@ -1,4 +1,4 @@ -package gunfish +package config import ( "crypto/tls" @@ -7,7 +7,26 @@ import ( "time" goconf "github.com/kayac/go-config" - "github.com/sirupsen/logrus" +) + +// Limit values +const ( + MaxWorkerNum = 119 // Maximum of worker number + MinWorkerNum = 1 // Minimum of worker number + MaxQueueSize = 40960 // Maximum queue size. + MinQueueSize = 128 // Minimum Queue size. + MaxRequestSize = 5000 // Maximum of requset count. + MinRequestSize = 1 // Minimum of request size. + LimitApnsTokenByteSize = 100 // Payload byte size. +) + +const ( + // Default array size of posted data. If not configures at file, this value is set. + DefaultRequestQueueSize = 2000 + // Default port number of provider server + DefaultPort = 8003 + // Default supervisor's queue size. If not configures at file, this value is set. + DefaultQueueSize = 1000 ) // Config is the configure of an APNS provider server @@ -31,17 +50,18 @@ type SectionProvider struct { // SectionApns is the configure which is loaded from gunfish.toml type SectionApns struct { Host string - SkipInsecure bool `toml:"skip_insecure"` CertFile string `toml:"cert_file"` KeyFile string `toml:"key_file"` + Kid string `toml:"kid"` + TeamID string `toml:"team_id"` CertificateNotAfter time.Time - enabled bool + Enabled bool } // SectionFCM is the configuration of fcm type SectionFCM struct { APIKey string `toml:"api_key"` - enabled bool + Enabled bool } // DefaultLoadConfig loads default /etc/gunfish.toml @@ -54,7 +74,6 @@ func LoadConfig(fn string) (Config, error) { var config Config if err := goconf.LoadWithEnvTOML(&config, fn); err != nil { - LogWithFields(logrus.Fields{"type": "load_config"}).Warnf("%v %s %s", config, err, fn) return config, err } @@ -73,7 +92,6 @@ func LoadConfig(fn string) (Config, error) { // validates config parameters if err := (&config).validateConfig(); err != nil { - LogWithFields(logrus.Fields{"type": "load_config"}).Error(err) return config, err } @@ -81,14 +99,17 @@ func LoadConfig(fn string) (Config, error) { } func (c *Config) validateConfig() error { - if c.Apns.CertFile != "" && c.Apns.KeyFile != "" { - c.Apns.enabled = true - if err := c.validateConfigApns(); err != nil { + if err := c.validateConfigProvider(); err != nil { + return err + } + if (c.Apns.CertFile != "" && c.Apns.KeyFile != "") || (c.Apns.TeamID != "" && c.Apns.Kid != "") { + c.Apns.Enabled = true + if err := c.validateConfigAPNs(); err != nil { return err } } if c.FCM.APIKey != "" { - c.FCM.enabled = true + c.FCM.Enabled = true if err := c.validateConfigFCM(); err != nil { return err } @@ -96,31 +117,7 @@ func (c *Config) validateConfig() error { return nil } -func (c *Config) validateConfigFCM() error { - return nil -} - -func (c *Config) validateConfigApns() error { - // check certificate files and expiration - cert, err := tls.LoadX509KeyPair(c.Apns.CertFile, c.Apns.KeyFile) - if err != nil { - return fmt.Errorf("Invalid certificate pair for APNS: %s", err) - } - now := time.Now() - for _, _ct := range cert.Certificate { - ct, err := x509.ParseCertificate(_ct) - if err != nil { - return fmt.Errorf("Cannot parse X509 certificate") - } - if now.Before(ct.NotBefore) || now.After(ct.NotAfter) { - return fmt.Errorf("Certificate is expired. Subject: %s, NotBefore: %s, NotAfter: %s", ct.Subject, ct.NotBefore, ct.NotAfter) - } - if c.Apns.CertificateNotAfter.IsZero() || c.Apns.CertificateNotAfter.Before(ct.NotAfter) { - // hold minimum not after - c.Apns.CertificateNotAfter = ct.NotAfter - } - } - +func (c *Config) validateConfigProvider() error { if c.Provider.RequestQueueSize < MinRequestSize || c.Provider.RequestQueueSize > MaxRequestSize { return fmt.Errorf("MaxRequestSize was out of available range: %d. (%d-%d)", c.Provider.RequestQueueSize, MinRequestSize, MaxRequestSize) @@ -138,3 +135,32 @@ func (c *Config) validateConfigApns() error { return nil } + +func (c *Config) validateConfigFCM() error { + return nil +} + +func (c *Config) validateConfigAPNs() error { + if c.Apns.CertFile != "" && c.Apns.KeyFile != "" { + // check certificate files and expiration + cert, err := tls.LoadX509KeyPair(c.Apns.CertFile, c.Apns.KeyFile) + if err != nil { + return fmt.Errorf("Invalid certificate pair for APNS: %s", err) + } + now := time.Now() + for _, _ct := range cert.Certificate { + ct, err := x509.ParseCertificate(_ct) + if err != nil { + return fmt.Errorf("Cannot parse X509 certificate") + } + if now.Before(ct.NotBefore) || now.After(ct.NotAfter) { + return fmt.Errorf("Certificate is expired. Subject: %s, NotBefore: %s, NotAfter: %s", ct.Subject, ct.NotBefore, ct.NotAfter) + } + if c.Apns.CertificateNotAfter.IsZero() || c.Apns.CertificateNotAfter.Before(ct.NotAfter) { + // hold minimum not after + c.Apns.CertificateNotAfter = ct.NotAfter + } + } + } + return nil +} diff --git a/config_test.go b/config/config_test.go similarity index 83% rename from config_test.go rename to config/config_test.go index 3a6eeed..a0fe8f5 100644 --- a/config_test.go +++ b/config/config_test.go @@ -1,4 +1,4 @@ -package gunfish +package config import ( "os" @@ -10,7 +10,7 @@ func TestLoadTomlConfigFile(t *testing.T) { t.Error(err) } - c, err := LoadConfig("./test/gunfish_test.toml") + c, err := LoadConfig("../test/gunfish_test.toml") if err != nil { t.Error(err) } diff --git a/conf/gunfish.toml.example b/config/gunfish.toml.example similarity index 100% rename from conf/gunfish.toml.example rename to config/gunfish.toml.example diff --git a/const.go b/const.go index 69a88af..50ef4c7 100644 --- a/const.go +++ b/const.go @@ -4,17 +4,6 @@ import ( "time" ) -// Limit values -const ( - MaxWorkerNum = 119 // Maximum of worker number - MinWorkerNum = 1 // Minimum of worker number - MaxQueueSize = 40960 // Maximum queue size. - MinQueueSize = 128 // Minimum Queue size. - MaxRequestSize = 5000 // Maximum of requset count. - MinRequestSize = 1 // Minimum of request size. - LimitApnsTokenByteSize = 100 // Payload byte size. -) - // Default values const ( // SendRetryCount is the threashold which is resend count. @@ -27,12 +16,6 @@ const ( SenderNum = 20 RequestPerSec = 2000 - // Default array size of posted data. If not configures at file, this value is set. - DefaultRequestQueueSize = 2000 - // Default port number of provider server - DefaultPort = 8003 - // Default supervisor's queue size. If not configures at file, this value is set. - DefaultQueueSize = 1000 // About the average time of response from apns. That value is not accurate // because that is defined heuristically in Japan. AverageResponseTime = time.Millisecond * 150 diff --git a/gunfish_test.go b/gunfish_test.go index cd65ddd..cbfb122 100644 --- a/gunfish_test.go +++ b/gunfish_test.go @@ -20,7 +20,7 @@ func BenchmarkGunfish(b *testing.B) { b.StopTimer() go func() { - StartServer(config, Test) + StartServer(conf, Test) }() time.Sleep(time.Second * 1) @@ -44,7 +44,7 @@ func BenchmarkGunfish(b *testing.B) { } func do(jsons *bytes.Buffer) error { - u, err := url.Parse(fmt.Sprintf("http://localhost:%d/push/apns", config.Provider.Port)) + u, err := url.Parse(fmt.Sprintf("http://localhost:%d/push/apns", conf.Provider.Port)) if err != nil { return err } diff --git a/mock/apns_server.go b/mock/apns_server.go new file mode 100644 index 0000000..29b8cc8 --- /dev/null +++ b/mock/apns_server.go @@ -0,0 +1,96 @@ +package mock + +import ( + "encoding/json" + "fmt" + "io" + "log" + "math/rand" + "net/http" + "strings" + "time" + + "github.com/kayac/Gunfish/apns" +) + +const ( + ApplicationJSON = "application/json" + LimitApnsTokenByteSize = 100 // Payload byte size. +) + +// StartAPNSMockServer starts HTTP/2 server for mock +func APNsMockServer(verbose bool) *http.ServeMux { + mux := http.NewServeMux() + + mux.HandleFunc("/3/device/", func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + defer func() { + if verbose { + log.Printf("reqtime:%f proto:%s method:%s path:%s host:%s", reqtime(start), r.Proto, r.Method, r.URL.Path, r.RemoteAddr) + } + }() + + // sets the response time from apns server + time.Sleep(time.Millisecond*200 + time.Millisecond*(time.Duration(rand.Int63n(200)-100))) + + // only allow path which pattern is '/3/device/:token' + splitPath := strings.Split(r.URL.Path, "/") + if len(splitPath) != 4 { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "404 Not found") + return + } + + w.Header().Set("Content-Type", ApplicationJSON) + + token := splitPath[len(splitPath)-1] + if len(([]byte(token))) > LimitApnsTokenByteSize || token == "baddevicetoken" { + w.Header().Set("apns-id", "apns-id") + w.WriteHeader(http.StatusBadRequest) + createErrorResponse(w, apns.BadDeviceToken, http.StatusBadRequest) + } else if token == "missingtopic" { + w.WriteHeader(http.StatusBadRequest) + createErrorResponse(w, apns.MissingTopic, http.StatusBadRequest) + } else if token == "unregistered" { + // If the value in the :status header is 410, the value of this key is + // the last time at which APNs confirmed that the device token was + // no longer valid for the topic. + // + // Stop pushing notifications until the device registers a token with + // a later timestamp with your provider. + w.WriteHeader(http.StatusGone) + createErrorResponse(w, apns.Unregistered, http.StatusGone) + } else if token == "expiredprovidertoken" { + w.WriteHeader(http.StatusForbidden) + createErrorResponse(w, apns.ExpiredProviderToken, http.StatusForbidden) + } else { + w.Header().Set("apns-id", "apns-id") + w.WriteHeader(http.StatusOK) + } + + return + }) + + return mux +} + +func createErrorResponse(w io.Writer, ermsg apns.ErrorResponseCode, status int) error { + enc := json.NewEncoder(w) + var er apns.ErrorResponse + if status == http.StatusGone { + er = apns.ErrorResponse{ + Reason: ermsg.String(), + Timestamp: time.Now().Unix(), + } + } else { + er = apns.ErrorResponse{ + Reason: ermsg.String(), + } + } + return enc.Encode(er) +} + +func reqtime(start time.Time) float64 { + diff := time.Now().Sub(start) + return diff.Seconds() +} diff --git a/server.go b/server.go index 87a3982..f15e263 100644 --- a/server.go +++ b/server.go @@ -15,6 +15,7 @@ import ( "github.com/fukata/golang-stats-api-handler" "github.com/kayac/Gunfish/apns" + "github.com/kayac/Gunfish/config" "github.com/kayac/Gunfish/fcm" "github.com/lestrrat/go-server-starter/listener" "github.com/shogo82148/go-gracedown" @@ -51,7 +52,7 @@ func (rh DefaultResponseHandler) HookCmd() string { } // StartServer starts an apns provider server on http. -func StartServer(conf Config, env Environment) { +func StartServer(conf config.Config, env Environment) { // Initialize DefaultResponseHandler if response handlers are not defined. if successResponseHandler == nil { InitSuccessResponseHandler(DefaultResponseHandler{}) @@ -137,13 +138,13 @@ func StartServer(conf Config, env Environment) { }).Infof("Starts provider on :%d ...", conf.Provider.Port) mux := http.NewServeMux() - if conf.Apns.enabled { + if conf.Apns.Enabled { LogWithFields(logrus.Fields{ "type": "provider", }).Infof("Enable endpoint /push/apns") mux.HandleFunc("/push/apns", prov.pushAPNsHandler()) } - if conf.FCM.enabled { + if conf.FCM.Enabled { LogWithFields(logrus.Fields{ "type": "provider", }).Infof("Enable endpoint /push/fcm") @@ -342,8 +343,8 @@ func validatePostedData(ps []PostedData) error { return fmt.Errorf("PostedData must not be empty: %v", ps) } - if len(ps) > MaxRequestSize { - return fmt.Errorf("PostedData was too long. Be less than %d: %v", MaxRequestSize, len(ps)) + if len(ps) > config.MaxRequestSize { + return fmt.Errorf("PostedData was too long. Be less than %d: %v", config.MaxRequestSize, len(ps)) } for _, p := range ps { diff --git a/server_test.go b/server_test.go index 6d8f79f..e9a8f19 100644 --- a/server_test.go +++ b/server_test.go @@ -2,28 +2,56 @@ package gunfish import ( "bytes" + "crypto/tls" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" + "os" "reflect" "testing" "time" "github.com/kayac/Gunfish/apns" + "github.com/kayac/Gunfish/config" + "github.com/kayac/Gunfish/mock" "github.com/sirupsen/logrus" + "golang.org/x/net/http2" ) -func init() { - InitErrorResponseHandler(DefaultResponseHandler{hook: `cat `}) - InitSuccessResponseHandler(DefaultResponseHandler{}) - logrus.SetLevel(logrus.WarnLevel) - config.Apns.Host = MockServer +func TestMain(m *testing.M) { + runner := func() int { + apns.ClientTransport = func(cert tls.Certificate) *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + }, + } + } + InitErrorResponseHandler(DefaultResponseHandler{hook: `cat `}) + InitSuccessResponseHandler(DefaultResponseHandler{}) + logrus.SetLevel(logrus.WarnLevel) + + ts := httptest.NewUnstartedServer(mock.APNsMockServer(false)) + if err := http2.ConfigureServer(ts.Config, nil); err != nil { + return 1 + } + ts.TLS = ts.Config.TLSConfig + ts.StartTLS() + conf.Apns.Host = ts.URL + + code := m.Run() + + return code + } + + os.Exit(runner()) } func TestInvalidCertification(t *testing.T) { - c, _ := LoadConfig("./test/gunfish_test.toml") + c, _ := config.LoadConfig("./test/gunfish_test.toml") c.Apns.CertFile = "./test/invalid.crt" c.Apns.KeyFile = "./test/invalid.key" ss, err := StartSupervisor(&c) @@ -33,7 +61,7 @@ func TestInvalidCertification(t *testing.T) { } func TestSuccessToPostJson(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} handler := prov.pushAPNsHandler() @@ -67,7 +95,7 @@ func TestSuccessToPostJson(t *testing.T) { } func TestFailedToPostInvalidJson(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} handler := prov.pushFCMHandler() @@ -91,7 +119,7 @@ func TestFailedToPostInvalidJson(t *testing.T) { } func TestFailedToPostMalformedJson(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} handler := prov.pushAPNsHandler() @@ -133,15 +161,15 @@ func TestFailedToPostMalformedJson(t *testing.T) { } func TestEnqueueTooManyRequest(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} - srvStats = NewStats(config) + srvStats = NewStats(conf) handler := prov.pushAPNsHandler() // When queue stack is full, return 503 var manyNum int - tp := ((config.Provider.RequestQueueSize * int(AverageResponseTime/time.Millisecond)) / 1000) / SenderNum - dif := (RequestPerSec - config.Provider.RequestQueueSize/tp) + tp := ((conf.Provider.RequestQueueSize * int(AverageResponseTime/time.Millisecond)) / 1000) / SenderNum + dif := (RequestPerSec - conf.Provider.RequestQueueSize/tp) if dif > 0 { manyNum = dif * int(FlowRateInterval/time.Second) * 2 } else { @@ -198,12 +226,12 @@ func TestEnqueueTooManyRequest(t *testing.T) { } func TestTooLargeRequest(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} - srvStats = NewStats(config) + srvStats = NewStats(conf) handler := prov.pushAPNsHandler() - jsons := createJSONPostedData(MaxRequestSize + 1) // Too many requests + jsons := createJSONPostedData(config.MaxRequestSize + 1) // Too many requests r, err := newRequest(jsons, "POST", ApplicationJSON) if err != nil { t.Errorf("%s", err) @@ -219,7 +247,7 @@ func TestTooLargeRequest(t *testing.T) { } func TestMethodNotAllowed(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} handler := prov.pushAPNsHandler() @@ -239,7 +267,7 @@ func TestMethodNotAllowed(t *testing.T) { } func TestUnsupportedMediaType(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} handler := prov.pushAPNsHandler() @@ -264,9 +292,9 @@ func TestUnsupportedMediaType(t *testing.T) { } func TestStats(t *testing.T) { - sup, _ := StartSupervisor(&config) + sup, _ := StartSupervisor(&conf) prov := &Provider{sup: sup} - srvStats = NewStats(config) + srvStats = NewStats(conf) pushh := prov.pushAPNsHandler() statsh := prov.statsHandler() diff --git a/stat.go b/stat.go index ad3b693..4391bde 100644 --- a/stat.go +++ b/stat.go @@ -3,6 +3,8 @@ package gunfish import ( "os" "time" + + "github.com/kayac/Gunfish/config" ) // Stats stores metrics @@ -28,7 +30,7 @@ type Stats struct { } // NewStats initialize Stats -func NewStats(conf Config) Stats { +func NewStats(conf config.Config) Stats { return Stats{ Pid: os.Getpid(), StartAt: time.Now().Unix(), @@ -42,6 +44,8 @@ func (st *Stats) GetStats() *Stats { preUptime := st.Uptime st.Uptime = time.Now().Unix() - st.StartAt st.Period = st.Uptime - preUptime - st.CertificateExpireUntil = int64(st.CertificateNotAfter.Sub(time.Now()).Seconds()) + if !st.CertificateNotAfter.IsZero() { + st.CertificateExpireUntil = int64(st.CertificateNotAfter.Sub(time.Now()).Seconds()) + } return st } diff --git a/supervisor.go b/supervisor.go index a756f2c..fbde308 100644 --- a/supervisor.go +++ b/supervisor.go @@ -2,6 +2,7 @@ package gunfish import ( "bytes" + "errors" "fmt" "io" "net/http" @@ -13,6 +14,7 @@ import ( "time" "github.com/kayac/Gunfish/apns" + "github.com/kayac/Gunfish/config" "github.com/kayac/Gunfish/fcm" "github.com/satori/go.uuid" "github.com/sirupsen/logrus" @@ -78,7 +80,7 @@ func (s *Supervisor) EnqueueClientRequest(reqs *[]Request) error { } // StartSupervisor starts supervisor -func StartSupervisor(conf *Config) (Supervisor, error) { +func StartSupervisor(conf *config.Config) (Supervisor, error) { // Calculates each worker queue size to accept requests with a given parameter of requests per sec as flow rate. var wqSize int tp := ((conf.Provider.RequestQueueSize * int(AverageResponseTime/time.Millisecond)) / 1000) / SenderNum @@ -157,8 +159,8 @@ func StartSupervisor(conf *Config) (Supervisor, error) { ac *apns.Client fc *fcm.Client ) - if conf.Apns.enabled { - ac, err = apns.NewClient(conf.Apns.Host, conf.Apns.CertFile, conf.Apns.KeyFile, conf.Apns.SkipInsecure) + if conf.Apns.Enabled { + ac, err = apns.NewClient(conf.Apns) if err != nil { LogWithFields(logrus.Fields{ "type": "supervisor", @@ -166,7 +168,7 @@ func StartSupervisor(conf *Config) (Supervisor, error) { break } } - if conf.FCM.enabled { + if conf.FCM.Enabled { fc, err = fcm.NewClient(conf.FCM.APIKey, nil, fcm.ClientTimeout) if err != nil { LogWithFields(logrus.Fields{ @@ -240,7 +242,7 @@ func (s *Supervisor) Shutdown() { }).Infoln("Stoped supervisor.") } -func (s *Supervisor) spawnWorker(w Worker, conf *Config) { +func (s *Supervisor) spawnWorker(w Worker, conf *config.Config) { atomic.AddInt64(&(srvStats.Workers), 1) defer func() { atomic.AddInt64(&(srvStats.Workers), -1) @@ -333,24 +335,7 @@ func handleAPNsResponse(resp SenderResponse, retryq chan<- Request, cmdq chan Co onResponse(result, errorResponseHandler.HookCmd(), cmdq) } else { // if 'result' is nil, HTTP connection error with APNS. - LogWithFields(logf).Warnf("response is nil. reason: %s", resp.Err.Error()) - if req.Tries < SendRetryCount { - req.Tries++ - atomic.AddInt64(&(srvStats.RetryCount), 1) - logf["resend_cnt"] = req.Tries - - select { - case retryq <- req: - LogWithFields(logf). - Debugf("Retry to enqueue into retryq because of http connection error with APNS.") - default: - LogWithFields(logf). - Warnf("Supervisor retry queue is full.") - } - } else { - LogWithFields(logf). - Warnf("Retry count is over than %d. Could not deliver notification.", SendRetryCount) - } + retry(retryq, req, errors.New("http connection error between APNs"), logf) } } else { atomic.AddInt64(&(srvStats.SentCount), 1) @@ -361,6 +346,12 @@ func handleAPNsResponse(resp SenderResponse, retryq chan<- Request, cmdq chan Co } if err := result.Err(); err != nil { atomic.AddInt64(&(srvStats.ErrCount), 1) + + // retry when provider auhentication token is expired + if err.Error() == apns.ExpiredProviderToken.String() { + retry(retryq, req, err, logf) + } + onResponse(result, errorResponseHandler.HookCmd(), cmdq) LogWithFields(logf).Errorf("%s", err) } else { @@ -571,3 +562,23 @@ func invokePipe(hook string, src io.Reader) ([]byte, error) { err = cmd.Run() return b.Bytes(), err } + +func retry(retryq chan<- Request, req Request, err error, logf logrus.Fields) { + if req.Tries < SendRetryCount { + req.Tries++ + atomic.AddInt64(&(srvStats.RetryCount), 1) + logf["resend_cnt"] = req.Tries + + select { + case retryq <- req: + LogWithFields(logf). + Debugf("%s: Retry to enqueue into retryq.", err.Error()) + default: + LogWithFields(logf). + Warnf("Supervisor retry queue is full.") + } + } else { + LogWithFields(logf). + Warnf("Retry count is over than %d. Could not deliver notification.", SendRetryCount) + } +} diff --git a/supervisor_test.go b/supervisor_test.go index b753180..ecaafce 100644 --- a/supervisor_test.go +++ b/supervisor_test.go @@ -9,17 +9,19 @@ import ( "time" "github.com/kayac/Gunfish/apns" + "github.com/kayac/Gunfish/config" "github.com/sirupsen/logrus" ) var ( - config, _ = LoadConfig("./test/gunfish_test.toml") + conf, _ = config.LoadConfig("./test/gunfish_test.toml") ) type TestResponseHandler struct { scoreboard map[string]*int wg *sync.WaitGroup hook string + mu sync.Mutex } func (tr *TestResponseHandler) Done(token string) { @@ -27,6 +29,8 @@ func (tr *TestResponseHandler) Done(token string) { } func (tr *TestResponseHandler) Countup(name string) { + tr.mu.Lock() + defer tr.mu.Unlock() *(tr.scoreboard[name])++ } @@ -34,15 +38,7 @@ func (tr TestResponseHandler) OnResponse(result Result) { tr.wg.Add(1) if err := result.Err(); err != nil { logrus.Warnf(err.Error()) - if err.Error() == apns.MissingTopic.String() { - tr.Countup(apns.MissingTopic.String()) - } - if err.Error() == apns.BadDeviceToken.String() { - tr.Countup(apns.BadDeviceToken.String()) - } - if err.Error() == apns.Unregistered.String() { - tr.Countup(apns.Unregistered.String()) - } + tr.Countup(err.Error()) } else { tr.Countup("success") } @@ -55,11 +51,11 @@ func (tr TestResponseHandler) HookCmd() string { func init() { logrus.SetLevel(logrus.WarnLevel) - config.Apns.Host = MockServer + conf.Apns.Host = MockServer } func TestStartAndStopSupervisor(t *testing.T) { - sup, err := StartSupervisor(&config) + sup, err := StartSupervisor(&conf) if err != nil { t.Errorf("cannot start supvisor: %s", err.Error()) } @@ -82,8 +78,15 @@ func TestStartAndStopSupervisor(t *testing.T) { func TestEnqueuRequestToSupervisor(t *testing.T) { // Prepare wg := sync.WaitGroup{} - score := make(map[string]*int, 4) - for _, v := range []string{apns.MissingTopic.String(), apns.BadDeviceToken.String(), apns.Unregistered.String(), "success"} { + score := make(map[string]*int, 5) + boardList := []string{ + apns.MissingTopic.String(), + apns.BadDeviceToken.String(), + apns.Unregistered.String(), + apns.ExpiredProviderToken.String(), + "success", + } + for _, v := range boardList { x := 0 score[v] = &x } @@ -91,51 +94,82 @@ func TestEnqueuRequestToSupervisor(t *testing.T) { etr := TestResponseHandler{ wg: &wg, scoreboard: score, - hook: config.Provider.ErrorHook, + hook: conf.Provider.ErrorHook, + mu: sync.Mutex{}, } str := TestResponseHandler{ wg: &wg, scoreboard: score, + mu: sync.Mutex{}, } InitErrorResponseHandler(etr) InitSuccessResponseHandler(str) - sup, err := StartSupervisor(&config) + sup, err := StartSupervisor(&conf) if err != nil { t.Errorf("cannot start supervisor: %s", err.Error()) } + defer sup.Shutdown() // test success requests reqs := repeatRequestData("1122334455667788112233445566778811223344556677881122334455667788", 10) for range []int{0, 1, 2, 3, 4, 5, 6} { sup.EnqueueClientRequest(&reqs) } + time.Sleep(time.Millisecond * 500) + wg.Wait() + if g, w := *(score["success"]), 70; g != w { + t.Errorf("not match success count: got %d want %d", g, w) + } // test error requests - mreqs := repeatRequestData("missingtopic", 1) - sup.EnqueueClientRequest(&mreqs) - - ureqs := repeatRequestData("unregistered", 1) - sup.EnqueueClientRequest(&ureqs) - - breqs := repeatRequestData("baddevicetoken", 1) - sup.EnqueueClientRequest(&breqs) + testTable := []struct { + errToken string + num int + msleep time.Duration + errCode apns.ErrorResponseCode + expect int + }{ + { + errToken: "missingtopic", + num: 1, + msleep: 300, + errCode: apns.MissingTopic, + expect: 1, + }, + { + errToken: "unregistered", + num: 1, + msleep: 300, + errCode: apns.Unregistered, + expect: 1, + }, + { + errToken: "baddevicetoken", + num: 1, + msleep: 300, + errCode: apns.BadDeviceToken, + expect: 1, + }, + { + errToken: "expiredprovidertoken", + num: 1, + msleep: 5000, + errCode: apns.ExpiredProviderToken, + expect: 1 * SendRetryCount, + }, + } - time.Sleep(time.Second * 1) - wg.Wait() - sup.Shutdown() + for _, tt := range testTable { + reqs := repeatRequestData(tt.errToken, tt.num) + sup.EnqueueClientRequest(&reqs) + time.Sleep(time.Millisecond * tt.msleep) + wg.Wait() - if *(score[apns.MissingTopic.String()]) != 1 { - t.Errorf("Expected MissingTopic count is 1 but got %d", *(score[apns.MissingTopic.String()])) - } - if *(score[apns.Unregistered.String()]) != 1 { - t.Errorf("Expected Unregistered count is 1 but got %d", *(score[apns.Unregistered.String()])) - } - if *(score[apns.BadDeviceToken.String()]) != 1 { - t.Errorf("Expected BadDeviceToken count is 1 but got %d", *(score[apns.BadDeviceToken.String()])) - } - if *(score["success"]) != 70 { - t.Errorf("Expected success count is 70 but got %d", *(score["success"])) + errReason := tt.errCode.String() + if g, w := *(score[errReason]), tt.expect; g != w { + t.Errorf("not match %s count: got %d want %d", errReason, g, w) + } } } diff --git a/test/gunfish_test.toml b/test/gunfish_test.toml index 3ef0be3..b715ae6 100644 --- a/test/gunfish_test.toml +++ b/test/gunfish_test.toml @@ -7,7 +7,6 @@ max_connections = 2000 error_hook = "{{ env `TEST_GUNFISH_HOOK_CMD` `cat ` }}" [apns] -skip_insecure = true key_file = "./test/server.key" cert_file = "./test/server.crt" sender_num = 50 diff --git a/test/h2o/apnsmock.rb b/test/h2o/apnsmock.rb deleted file mode 100644 index 03fc4ab..0000000 --- a/test/h2o/apnsmock.rb +++ /dev/null @@ -1,40 +0,0 @@ -LIMIT_APNS_TOKEN_BYTE_SIZE = 100 - -class ApnsMock - def call(env) - now = Time.now - if /\/3\/device\/(.*)?$/.match(env["PATH_INFO"]) - Sleep::usleep( 750 + (rand(1500) - 750 )) - token = $1 - if token.length > LIMIT_APNS_TOKEN_BYTE_SIZE || token == "baddevicetoken" - print_time now - return [400, { - "content-type" => "application/json" - }, ['{"reason":"BadDeviceToken"}']] - elsif token == "unregistered" - print_time now - return [410, { - "content-type" => "application/json" - }, ['{"reason":"Unregistered","timestamp":1454402113}']] - elsif token == "missingtopic" - print_time now - return [400, { - "content-type" => "application/json" - }, ['{"reason":"MissingTopic"}']] - end - print_time now - return [200, {}, {}] - else - print_time now - return [404, { - "content-type" => "application/json" - }, ['{"reason":"not found"}']] - end - end -end - -def print_time(now) - p Time.now - now -end - -ApnsMock.new diff --git a/test/scripts/build_h2o.sh b/test/scripts/build_h2o.sh deleted file mode 100755 index 6da037e..0000000 --- a/test/scripts/build_h2o.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -gsed=$(which gsed) -sed=${gsed:-sed} - -set -xeu - -version='1.6.3' - -if [[ ! -f "h2o-$version/h2o" ]] ; then - rm -rf h2o-$version - wget https://github.com/h2o/h2o/archive/v$version.tar.gz - tar xzf v$version.tar.gz -fi -cd h2o-$version - -insert_num=$(grep -n MRuby misc/mruby_config.rb | awk -F':' '{print $1}') -insert_gem="conf.gem :git => 'https://github.com/matsumoto-r/mruby-sleep.git'" -$sed -i "${insert_num} a ${insert_gem}" misc/mruby_config.rb - -cmake -DWITH_BUNDLED_SSL=on . -make -sudo make install diff --git a/test/tools/apnsmock/apnsmock.go b/test/tools/apnsmock/apnsmock.go new file mode 100644 index 0000000..5be8df4 --- /dev/null +++ b/test/tools/apnsmock/apnsmock.go @@ -0,0 +1,30 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + + "github.com/kayac/Gunfish/mock" +) + +func main() { + var ( + port int + keyFile, certFile string + verbose bool + ) + + flag.IntVar(&port, "port", 2195, "apns mock server port") + flag.StringVar(&keyFile, "cert-file", "", "apns mock server key file") + flag.StringVar(&certFile, "key-file", "", "apns mock server cert file") + flag.BoolVar(&verbose, "verbose", false, "verbose flag") + flag.Parse() + + mux := mock.APNsMockServer(verbose) + log.Println("start apnsmock server") + if err := http.ListenAndServeTLS(fmt.Sprintf(":%d", port), keyFile, certFile, mux); err != nil { + log.Fatal(err) + } +} diff --git a/test/tools/gunfish-cli/gunfish-cli.go b/test/tools/gunfish-cli/gunfish-cli.go new file mode 100644 index 0000000..0c68153 --- /dev/null +++ b/test/tools/gunfish-cli/gunfish-cli.go @@ -0,0 +1,161 @@ +package main + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "strings" +) + +func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + + var ( + typ string + port int + host string + apnsTopic string + count int + message string + sound string + options string + token string + dryrun bool + verbose bool + jsonFile string + ) + + flag.StringVar(&typ, "type", "apns", "push notification type. 'apns' or 'fcm' (fcm not implemented)") + flag.IntVar(&count, "count", 1, "send count") + flag.IntVar(&port, "port", 8003, "gunfish port") + flag.StringVar(&host, "host", "localhost", "gunfish host") + flag.StringVar(&apnsTopic, "apns-topic", "", "apns topic") + flag.StringVar(&message, "message", "test notification", "push notification message") + flag.StringVar(&sound, "sound", "default", "push notification sound (default: 'default')") + flag.StringVar(&options, "options", "", "options (key1=value1,key2=value2...)") + flag.StringVar(&token, "token", "", "apns device token (required)") + flag.BoolVar(&dryrun, "dryrun", false, "dryrun") + flag.BoolVar(&verbose, "verbose", false, "dryrun") + flag.StringVar(&jsonFile, "json-file", "", "json input file") + + flag.Parse() + + switch typ { + case "apns": + // OK + case "fcm": + return errors.New("[ERROR] not implemented") + default: + return errors.New("[ERROR] wrong push notification type") + } + + if verbose { + log.Printf("host: %s, port: %d, send count: %d", host, port, count) + } + + opts := map[string]string{} + payloads := make([]map[string]interface{}, count) + if jsonFile == "" { + if options != "" { + for _, opt := range strings.Split(options, ",") { + kv := strings.Split(opt, "=") + key, val := kv[0], kv[1] + opts[key] = val + } + } + + for i := 0; i < count; i++ { + payloads[i] = map[string]interface{}{} + payloads[i] = buildPayload(token, message, sound, apnsTopic, opts) + } + } + + if dryrun { + log.Println("[dryrun] checks request payload:") + if jsonFile == "" { + out, err := json.MarshalIndent(payloads, "", " ") + if err != nil { + return err + } + fmt.Println(string(out)) + } + + out, err := ioutil.ReadFile(jsonFile) + if err != nil { + return err + } + fmt.Println(string(out)) + return nil + } + + if verbose && len(payloads) > 0 { + log.Printf("post data: %#v", payloads) + } + endpoint := fmt.Sprintf("http://%s:%d/push/apns", host, port) + req, err := newRequest(endpoint, jsonFile, payloads) + if err != nil { + return err + } + req.Header.Set("content-type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + out, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + fmt.Println(string(out)) + return nil +} + +func buildPayload(token, message, sound, apnsTopic string, opts map[string]string) map[string]interface{} { + payload := map[string]interface{}{ + "aps": map[string]string{ + "alert": message, + "sound": sound, + }, + } + for k, v := range opts { + payload[k] = v + } + + return map[string]interface{}{ + "payload": payload, + "token": token, + "header": map[string]interface{}{ + "apns-topic": apnsTopic, + }, + } +} + +func newRequest(endpoint, jsonFile string, payloads []map[string]interface{}) (*http.Request, error) { + if jsonFile == "" { + b := &bytes.Buffer{} + err := json.NewEncoder(b).Encode(payloads) + if err != nil { + return nil, err + } + return http.NewRequest(http.MethodPost, endpoint, b) + } + + b, err := os.Open(jsonFile) + if err != nil { + return nil, err + } + return http.NewRequest(http.MethodPost, endpoint, b) +}