Skip to content

Commit

Permalink
pass context.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
fujiwara committed Oct 18, 2024
1 parent 30c4ed0 commit 2c2856d
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test:
go test

build:
goreleaser build --snapshot --rm-dist
goreleaser build --snapshot --clean

clean:
rm -rf dist/*
Expand Down
3 changes: 2 additions & 1 deletion cmd/irc-msgr/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"flag"
"log"
"os"
Expand All @@ -25,7 +26,7 @@ func main() {
log.SetOutput(filter)

flag.Parse()
err := nopaste.RunMsgr(config)
err := nopaste.RunMsgr(context.Background(), config)
if err != nil {
panic(err)
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/nopaste/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"flag"
"log"
"os"
Expand All @@ -23,7 +24,7 @@ func main() {
}
log.SetOutput(filter)

err := nopaste.Run(config)
err := nopaste.Run(context.Background(), config)
if err != nil {
panic(err)
}
Expand Down
13 changes: 9 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package nopaste

import (
"context"
"errors"
"io/ioutil"
"log"
"os"
"path"

goconfig "github.com/kayac/go-config"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (c *Config) Save() error {
if err != nil {
return err
}
return ioutil.WriteFile(c.filePath, data, 0644)
return os.WriteFile(c.filePath, data, 0644)
}

func (c *Config) SetFilePath(path string) {
Expand All @@ -71,7 +72,7 @@ func (c *Config) Storages() []Storage {
return c.storages
}

func LoadConfig(file string) (*Config, error) {
func LoadConfig(ctx context.Context, file string) (*Config, error) {
log.Println("[info] loading config file", file)
c := Config{filePath: file}
err := goconfig.LoadWithEnv(&c, file)
Expand All @@ -81,7 +82,11 @@ func LoadConfig(file string) (*Config, error) {

if c.S3 != nil {
log.Printf("[info] using S3 storage s3://%s", path.Join(c.S3.Bucket, c.S3.KeyPrefix))
c.storages = append(c.storages, NewS3Storage(c.S3))
s, err := NewS3Storage(ctx, c.S3)
if err != nil {
return nil, err
}
c.storages = append(c.storages, s)
}
if c.DataDir != "" {
log.Printf("[info] using local storage %s", c.DataDir)
Expand Down
3 changes: 2 additions & 1 deletion config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nopaste_test

import (
"context"
"fmt"
"os"
"testing"
Expand All @@ -11,7 +12,7 @@ import (
func TestLoadConfig(t *testing.T) {
base := "http://nopaste.example.com"
os.Setenv("BASE_URL", base)
c, err := nopaste.LoadConfig("test/example.yaml")
c, err := nopaste.LoadConfig(context.Background(), "test/example.yaml")
if err != nil {
t.Error(err)
}
Expand Down
5 changes: 3 additions & 2 deletions msgr.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package nopaste

import (
"context"
"log"
"net/http"
)

const MsgrRoot = "/irc-msgr"

func RunMsgr(configFile string) error {
func RunMsgr(ctx context.Context, configFile string) error {
var err error
config, err = LoadConfig(configFile)
config, err = LoadConfig(ctx, configFile)
if err != nil {
return err
}
Expand Down
28 changes: 16 additions & 12 deletions nopaste.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ type nopasteContent struct {
LinkNames int
}

func Run(configFile string) error {
func Run(ctx context.Context, configFile string) error {
var err error
config, err = LoadConfig(configFile)
config, err = LoadConfig(ctx, configFile)
if err != nil {
return err
}
Expand Down Expand Up @@ -80,7 +80,7 @@ func rootHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
if _notice := req.FormValue("notice"); _notice == "0" {
np.LinkNames = 1
}
path, code := saveContent(np, chs)
path, code := saveContent(req.Context(), np, chs)
if code == http.StatusFound {
http.Redirect(w, req, path, code)
} else {
Expand Down Expand Up @@ -109,7 +109,7 @@ func serveHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
var f io.ReadCloser
var err error
for _, s := range config.Storages() {
if _f, _err := s.Load(id); _err == nil {
if _f, _err := s.Load(req.Context(), id); _err == nil {
log.Println("[debug] loaded from", s)
f, err = _f, _err
break
Expand All @@ -126,7 +126,7 @@ func serveHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
io.Copy(w, f)
}

func saveContent(np nopasteContent, chs []MessageChan) (string, int) {
func saveContent(ctx context.Context, np nopasteContent, chs []MessageChan) (string, int) {
if np.Text == "" {
log.Println("[warn] empty text")
return Root, http.StatusFound
Expand All @@ -136,7 +136,7 @@ func saveContent(np nopasteContent, chs []MessageChan) (string, int) {
id := hex[0:10]
log.Println("[info] save", id)

err := config.Storages()[0].Save(id, data)
err := config.Storages()[0].Save(ctx, id, data)
if err != nil {
log.Println("[warn]", err)
return Root, 500
Expand Down Expand Up @@ -174,6 +174,7 @@ type HttpNotification struct {
}

func snsHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
ctx := req.Context()
if req.Method != "POST" {
serverError(w, 400)
return
Expand Down Expand Up @@ -201,17 +202,20 @@ func snsHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
IconEmoji: ":amazonsns:",
Nick: "AmazonSNS",
}
saveContent(np, chs)
saveContent(ctx, np, chs)
case "SubscriptionConfirmation", "Notification":
if n.Type == "SubscriptionConfirmation" {
region, _ := getRegionFromARN(n.TopicArn)
snsSvc := NewSNS(region)
_, err := snsSvc.ConfirmSubscription(context.TODO(), &sns.ConfirmSubscriptionInput{
snsSvc, err := NewSNS(ctx, region)
if err != nil {
log.Println("[warn]", err)
break
}
if _, err := snsSvc.ConfirmSubscription(ctx, &sns.ConfirmSubscriptionInput{
Token: aws.String(n.Token),
TopicArn: aws.String(n.TopicArn),
AuthenticateOnUnsubscribe: aws.String("no"),
})
if err != nil {
}); err != nil {
log.Println("[warn]", err)
break
}
Expand Down Expand Up @@ -241,7 +245,7 @@ func snsHandler(w http.ResponseWriter, req *http.Request, chs []MessageChan) {
IconEmoji: ":amazonsns:",
Nick: "AmazonSNS",
}
saveContent(np, chs)
saveContent(ctx, np, chs)
}
io.WriteString(w, "OK")
}
9 changes: 5 additions & 4 deletions sns.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ package nopaste
import (
"context"
"errors"
"fmt"
"strings"

awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sns"
)

func NewSNS(region string) *sns.Client {
cfg, err := awsConfig.LoadDefaultConfig(context.TODO())
func NewSNS(ctx context.Context, region string) (*sns.Client, error) {
cfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region))
if err != nil {
panic(err)
return nil, fmt.Errorf("failed to load config, %v", err)
}
return sns.NewFromConfig(cfg)
return sns.NewFromConfig(cfg), nil
}

func getRegionFromARN(arn string) (string, error) {
Expand Down
32 changes: 17 additions & 15 deletions storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
)

type Storage interface {
Save(string, []byte) error
Load(string) (io.ReadCloser, error)
Save(context.Context, string, []byte) error
Load(context.Context, string) (io.ReadCloser, error)
}

type LocalStorage struct {
Expand All @@ -30,13 +30,13 @@ func NewLocalStorage(datadir string) *LocalStorage {
}
}

func (s *LocalStorage) Save(name string, data []byte) error {
func (s *LocalStorage) Save(_ context.Context, name string, data []byte) error {
f := filepath.Join(s.DataDir, name+".txt")
log.Println("[debug] save to", f)
return os.WriteFile(f, data, 0644)
}

func (s *LocalStorage) Load(name string) (io.ReadCloser, error) {
func (s *LocalStorage) Load(_ context.Context, name string) (io.ReadCloser, error) {
f := filepath.Join(s.DataDir, name+".txt")
log.Println("[debug] load from", f)
return os.Open(f)
Expand All @@ -48,25 +48,27 @@ type S3Storage struct {
svc *s3.Client
}

func NewS3Storage(c *S3Config) *S3Storage {
cfg, err := awsConfig.LoadDefaultConfig(context.TODO())
func NewS3Storage(ctx context.Context, c *S3Config) (*S3Storage, error) {
cfg, err := awsConfig.LoadDefaultConfig(ctx)
if err != nil {
panic(err)
return nil, fmt.Errorf("failed to load config, %v", err)
}
svc := s3.NewFromConfig(cfg)
return &S3Storage{
Bucket: c.Bucket,
KeyPrefix: c.KeyPrefix,
svc: svc,
}
}, nil
}

func (s *S3Storage) Load(name string) (io.ReadCloser, error) {
func (s *S3Storage) Load(ctx context.Context, name string) (io.ReadCloser, error) {
for _, name := range []string{s.objectName(name), name} {
result, err := s.svc.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(s.Bucket),
Key: aws.String(path.Join(s.KeyPrefix, name)),
})
result, err := s.svc.GetObject(ctx,
&s3.GetObjectInput{
Bucket: aws.String(s.Bucket),
Key: aws.String(path.Join(s.KeyPrefix, name)),
},
)
log.Printf("[debug] load from s3://%s", path.Join(s.Bucket, s.KeyPrefix, name))
if err == nil {
log.Printf("[debug] result %v", result)
Expand All @@ -76,7 +78,7 @@ func (s *S3Storage) Load(name string) (io.ReadCloser, error) {
return nil, fmt.Errorf("not found %s and %s", name, s.objectName(name))
}

func (s *S3Storage) Save(name string, b []byte) error {
func (s *S3Storage) Save(ctx context.Context, name string, b []byte) error {
name = s.objectName(name)
input := &s3.PutObjectInput{
Body: bytes.NewReader(b),
Expand All @@ -85,7 +87,7 @@ func (s *S3Storage) Save(name string, b []byte) error {
ContentType: aws.String("text/plain"),
}
log.Printf("[debug] save to s3://%s", path.Join(s.Bucket, s.KeyPrefix, name))
_, err := s.svc.PutObject(context.TODO(), input)
_, err := s.svc.PutObject(ctx, input)
return err
}

Expand Down

0 comments on commit 2c2856d

Please sign in to comment.