From db0e58ae7ed7143f3c5c2ecd9009d82cf0b0c844 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 22 Oct 2022 11:56:41 +0200 Subject: [PATCH] Add support for graceful shutdown Fixes #1014 Signed-off-by: Nicola Murino --- docs/full-configuration.md | 1 + docs/portable-mode.md | 6 ++ docs/post-connect-hook.md | 2 - go.mod | 38 +++++------ go.sum | 75 +++++++++++----------- internal/cmd/install_windows.go | 4 ++ internal/cmd/portable.go | 14 +++- internal/cmd/root.go | 20 ++++++ internal/cmd/serve.go | 1 + internal/cmd/start_windows.go | 1 + internal/common/actions.go | 4 ++ internal/common/common.go | 95 ++++++++++++++++++++++++---- internal/common/common_test.go | 38 +++++------ internal/common/connection.go | 9 ++- internal/common/connection_test.go | 7 ++ internal/common/eventmanager.go | 2 + internal/common/protocol_test.go | 90 +++++++++++++++++++++++++- internal/ftpd/server.go | 6 +- internal/httpd/httpd_test.go | 6 +- internal/httpd/server.go | 6 +- internal/plugin/plugin.go | 4 +- internal/service/service.go | 13 +++- internal/service/service_portable.go | 4 ++ internal/service/service_windows.go | 1 + internal/service/signals_unix.go | 1 + internal/service/signals_windows.go | 2 + internal/sftpd/internal_test.go | 9 +++ internal/sftpd/scp.go | 14 +--- internal/sftpd/server.go | 4 +- internal/sftpd/sftpd_test.go | 9 ++- internal/sftpd/ssh_cmd.go | 3 + internal/webdavd/server.go | 6 +- 32 files changed, 371 insertions(+), 124 deletions(-) diff --git a/docs/full-configuration.md b/docs/full-configuration.md index c1f1220d2..9b167a10c 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -31,6 +31,7 @@ The `serve` command supports the following flags: - `--config-dir` string. Location of the config dir. This directory is used as the base for files with a relative path, eg. the private keys for the SFTP server or the SQLite database if you use SQLite as data provider. The configuration file, if not explicitly set, is looked for in this dir. We support reading from JSON, TOML, YAML, HCL, envfile and Java properties config files. The default config file name is `sftpgo` and therefore `sftpgo.json`, `sftpgo.yaml` and so on are searched. The default value is the working directory (".") or the value of `SFTPGO_CONFIG_DIR` environment variable. - `--config-file` string. This flag explicitly defines the path, name and extension of the config file. If must be an absolute path or a path relative to the configuration directory. The specified file name must have a supported extension (JSON, YAML, TOML, HCL or Java properties). The default value is empty or the value of `SFTPGO_CONFIG_FILE` environment variable. +- `--grace-time`, integer. Graceful shutdown is an option to initiate a shutdown without abrupt cancellation of the currently ongoing client-initiated transfer sessions. This grace time defines the number of seconds allowed for existing transfers to get completed before shutting down. 0 means disabled. The default value is `0` or the value of `SFTPGO_GRACE_TIME` environment variable. A graceful shutdown is triggered by an interrupt signal or by a service `stop` request on Windows, if a grace time is configured. - `--loaddata-from` string. Load users and folders from this file. The file must be specified as absolute path and it must contain a backup obtained using the `dumpdata` REST API or compatible content. The default value is empty or the value of `SFTPGO_LOADDATA_FROM` environment variable. - `--loaddata-clean` boolean. Determine if the loaddata-from file should be removed after a successful load. Default `false` or the value of `SFTPGO_LOADDATA_CLEAN` environment variable (1 or `true`, 0 or `false`). - `--loaddata-mode`, integer. Restore mode for data to load. 0 means new users are added, existing users are updated. 1 means new users are added, existing users are not modified. Default 1 or the value of `SFTPGO_LOADDATA_MODE` environment variable. diff --git a/docs/portable-mode.md b/docs/portable-mode.md index d5c3620cb..f935fbc41 100644 --- a/docs/portable-mode.md +++ b/docs/portable-mode.md @@ -74,6 +74,12 @@ Flags: virtual folder identified by this prefix and its contents --gcs-storage-class string + --grace-time int This grace time defines the number of + seconds allowed for existing transfers + to get completed before shutting down. + A graceful shutdown is triggered by an + interrupt signal. + -h, --help help for portable -l, --log-file-path string Leave empty to disable logging --log-level string Set the log level. diff --git a/docs/post-connect-hook.md b/docs/post-connect-hook.md index c2b0352c3..076a65986 100644 --- a/docs/post-connect-hook.md +++ b/docs/post-connect-hook.md @@ -2,8 +2,6 @@ This hook is executed as soon as a new connection is established. It notifies the connection's IP address and protocol. Based on the received response, the connection is accepted or rejected. Combining this hook with the [Post-login hook](./post-login-hook.md) you can implement your own (even for Protocol) blacklist/whitelist of IP addresses. -Please keep in mind that you can easily configure specialized program such as [Fail2ban](http://www.fail2ban.org/) for brute force protection. Executing a hook for each connection can be heavy. - The `post_connect_hook` can be defined as the absolute path of your program or an HTTP URL. If the hook defines an external program it can read the following environment variables: diff --git a/go.mod b/go.mod index bb01306a5..346b9d42b 100644 --- a/go.mod +++ b/go.mod @@ -8,15 +8,15 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1 github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387 - github.com/aws/aws-sdk-go-v2 v1.16.16 - github.com/aws/aws-sdk-go-v2/config v1.17.8 - github.com/aws/aws-sdk-go-v2/credentials v1.12.21 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.35 - github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.19 - github.com/aws/aws-sdk-go-v2/service/s3 v1.28.0 - github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.2 - github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 + github.com/aws/aws-sdk-go-v2 v1.17.0 + github.com/aws/aws-sdk-go-v2/config v1.17.9 + github.com/aws/aws-sdk-go-v2/credentials v1.12.22 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.18 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.36 + github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.20 + github.com/aws/aws-sdk-go-v2/service/s3 v1.29.0 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.17.0 github.com/cockroachdb/cockroach-go/v2 v2.2.16 github.com/coreos/go-oidc/v3 v3.4.0 github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 @@ -82,16 +82,16 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.1 // indirect github.com/ajg/form v1.5.1 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.8 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.24 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.14 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.24 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.15 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.9 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.18 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.17 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.17 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.11.23 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.11.24 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.7 // indirect github.com/aws/smithy-go v1.13.3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.0.1 // indirect @@ -156,7 +156,7 @@ require ( go.opencensus.io v0.23.0 // indirect golang.org/x/mod v0.6.0 // indirect golang.org/x/text v0.4.0 // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/tools v0.2.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20221018160656-63c7b68cfc55 // indirect diff --git a/go.sum b/go.sum index 13899e84a..fdc36a50a 100644 --- a/go.sum +++ b/go.sum @@ -224,67 +224,67 @@ github.com/aws/aws-sdk-go v1.44.45/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4 github.com/aws/aws-sdk-go v1.44.68/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.16.8/go.mod h1:6CpKuLXg2w7If3ABZCl/qZ6rEgwtjZTn4eAf4RcEyuw= -github.com/aws/aws-sdk-go-v2 v1.16.16 h1:M1fj4FE2lB4NzRb9Y0xdWsn2P0+2UHVxwKyOa4YJNjk= -github.com/aws/aws-sdk-go-v2 v1.16.16/go.mod h1:SwiyXi/1zTUZ6KIAmLK5V5ll8SiURNUYOqTerZPaF9k= +github.com/aws/aws-sdk-go-v2 v1.17.0 h1:kWm8OZGx0Zvd6PsOfjFtwbw7+uWYp65DK8suo7WVznw= +github.com/aws/aws-sdk-go-v2 v1.17.0/go.mod h1:SwiyXi/1zTUZ6KIAmLK5V5ll8SiURNUYOqTerZPaF9k= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.3/go.mod h1:gNsR5CaXKmQSSzrmGxmwmct/r+ZBfbxorAuXYsj/M5Y= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.8 h1:tcFliCWne+zOuUfKNRn8JdFBuWPDuISDH08wD2ULkhk= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.8/go.mod h1:JTnlBSot91steJeti4ryyu/tLd4Sk84O5W22L7O2EQU= github.com/aws/aws-sdk-go-v2/config v1.15.15/go.mod h1:A1Lzyy/o21I5/s2FbyX5AevQfSVXpvvIDCoVFD0BC4E= -github.com/aws/aws-sdk-go-v2/config v1.17.8 h1:b9LGqNnOdg9vR4Q43tBTVWk4J6F+W774MSchvKJsqnE= -github.com/aws/aws-sdk-go-v2/config v1.17.8/go.mod h1:UkCI3kb0sCdvtjiXYiU4Zx5h07BOpgBTtkPu/49r+kA= +github.com/aws/aws-sdk-go-v2/config v1.17.9 h1:PyqFD7DTmOx5gdvjFwZH2Tx0vivy+cJdM3SE3NVoWZc= +github.com/aws/aws-sdk-go-v2/config v1.17.9/go.mod h1:NGC2Ut1x1Gl+qBdh4uGdqRTDtk6f3qS8VQ45kEoyAvM= github.com/aws/aws-sdk-go-v2/credentials v1.12.10/go.mod h1:g5eIM5XRs/OzIIK81QMBl+dAuDyoLN0VYaLP+tBqEOk= -github.com/aws/aws-sdk-go-v2/credentials v1.12.21 h1:4tjlyCD0hRGNQivh5dN8hbP30qQhMLBE/FgQR1vHHWM= -github.com/aws/aws-sdk-go-v2/credentials v1.12.21/go.mod h1:O+4XyAt4e+oBAoIwNUYkRg3CVMscaIJdmZBOcPgJ8D8= +github.com/aws/aws-sdk-go-v2/credentials v1.12.22 h1:HPig9ugqH7Eyf2aqNVAPOCp3L/N2vlQ/IiaTxwcrH8U= +github.com/aws/aws-sdk-go-v2/credentials v1.12.22/go.mod h1:XfHZqa+J1j2Am2GHrsWtg24tnkFkKxmWbWWel+W1zp0= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.9/go.mod h1:KDCCm4ONIdHtUloDcFvK2+vshZvx4Zmj7UMDfusuz5s= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 h1:r08j4sbZu/RVi+BNxkBJwPMUYY3P8mgSDuKkZ/ZN1lE= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17/go.mod h1:yIkQcCDYNsZfXpd5UX2Cy+sWA1jPgIhGTw9cOBzfVnQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.18 h1:63dqlW4EI4nfhmXJOUqP0zIaGEHoRPn1ahLz8hUOWrQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.18/go.mod h1:O3tSoDcot3jy62HNmq7ms16dPHQMR6nqQxooj8T53tI= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.21/go.mod h1:iIYPrQ2rYfZiB/iADYlhj9HHZ9TTi6PqKQPAqygohbE= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.35 h1:vors9KQrDxcobmg5EAdgqBlAw9RclaVlS9uIb5JKZC0= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.35/go.mod h1:4m/hcx6qeabg+3q/v0VAuYGlnY5hWv53egqjT/d1lFU= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.36 h1:DYIvpSIM9YTdid6yRZk/w2kJhJJIbFnL/76NfzmfaTs= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.36/go.mod h1:1vzWYwKGRitVzk7xD3y8Ko7lg26qX+Pxwb5uRaOPSlM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.15/go.mod h1:pWrr2OoHlT7M/Pd2y4HV3gJyPb3qj5qMmnPkKSNPYK4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 h1:s4g/wnzMf+qepSNgTvaQQHNxyMLKSawNhKCPNy++2xY= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23/go.mod h1:2DFxAQ9pfIRy0imBCJv+vZ2X6RKxves6fbnEuSry6b4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.24 h1:WFIoN2kiF95/4z4HNcJ9F9B0xFV0vrPlUOf3+uNIujM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.24/go.mod h1:ghMzB/j2wRbPx5/4jPYxJdOtCG2ggrtY01j8K7FMBDA= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.9/go.mod h1:08tUpeSGN33QKSO7fwxXczNfiwCpbj+GxK6XKwqWVv0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17 h1:/K482T5A3623WJgWT8w1yRAFK4RzGzEl7y39yhtn9eA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17/go.mod h1:pRwaTYCJemADaqCbUAxltMoHKata7hmB5PjEXeu0kfg= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.18 h1:c2RKF0UvfdVI6epHtFjDujlbiK+VeY85dP1i4gmYc5w= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.18/go.mod h1:fkQKYK/jUhCL/wNS1tOPrlYhr9vqutjCz4zZC1wBE1s= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.16/go.mod h1:CYmI+7x03jjJih8kBEEFKRQc40UjUokT0k7GbvrhhTc= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.24 h1:wj5Rwc05hvUSvKuOF29IYb9QrCLjU+rHAy/x/o0DK2c= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.24/go.mod h1:jULHjqqjDlbyTa7pfM7WICATnOv+iOhjletM3N0Xbu8= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.25 h1:q4TXoep+lPTJneYxlIdcBrlGmTrhfNwrfkdBt1+HqzA= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.25/go.mod h1:9uX0Ksj6Zmsd3iQIyVkwkPWUqhPF6TxT/t8zYwUiQEU= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.6/go.mod h1:O7Oc4peGZDEKlddivslfYFvAbgzvl/GH3J8j3JIGBXc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.14 h1:ZSIPAkAsCCjYrhqfw2+lNzWDzxzHXEckFkTePL5RSWQ= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.14/go.mod h1:AyGgqiKv9ECM6IZeNQtdT8NnMvUb3/2wokeq2Fgryto= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.15 h1:15q0OjFjny5qjCC8nI+4DH+MZFDC2/BtXxONBNnVZR8= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.15/go.mod h1:t7/Pw0mlxveHXyfzEkGjzQ59Xu9xUmzOfxe1S52TJ8Q= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.3/go.mod h1:gkb2qADY+OHaGLKNTYxMaQNacfeyQpZ4csDTQMeFmcw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.9 h1:Lh1AShsuIJTwMkoxVCAYPJgNG5H+eN6SmoUn8nOZ5wE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.9/go.mod h1:a9j48l6yL5XINLHLcOKInjdvknN+vWqPBxqeIDw7ktw= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.10/go.mod h1:Qks+dxK3O+Z2deAhNo6cJ8ls1bam3tUGUAcgxQP1c70= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.18 h1:BBYoNQt2kUZUUK4bIPsKrCcjVPUMNsgQpNAwhznK/zo= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.18/go.mod h1:NS55eQ4YixUJPTC+INxi2/jCqe1y2Uw3rnh9wEOVJxY= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.19 h1:jrV+VRNrUuzcwTZxdZMi1JtKMk71FN1H7VaF8XjGl44= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.19/go.mod h1:HGDDjLf/IyINXk4PcEZSEviZulqnePG76iq9/rC5qqo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.9/go.mod h1:yQowTpvdZkFVuHrLBXmczat4W+WJKg/PafBZnGBLga0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.17 h1:Jrd/oMh0PKQc6+BowB+pLEwLIgaQF29eYbe7E1Av9Ug= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.17/go.mod h1:4nYOrY41Lrbk2170/BGkcJKBhws9Pfn8MG3aGqjjeFI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.18 h1:5oiCDEOHnYkk7uTVI8Wv6ftdFfb6YlUUNzkeePVIPjY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.18/go.mod h1:QtCDHDOXunxeihz7iU15e09u9gRIeaa5WeE6FZVnGUo= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.9/go.mod h1:Rc5+wn2k8gFSi3V1Ch4mhxOzjMh+bYSXVFfVaqowQOY= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.17 h1:HfVVR1vItaG6le+Bpw6P4midjBDMKnjMyZnw9MXYUcE= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.17/go.mod h1:YqMdV+gEKCQ59NrB7rzrJdALeBIsYiVi8Inj3+KcqHI= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.18 h1:sk9Z5ZwZpLGq3q8ZhOsw8bORT2t8raWPsFrq/yMMbZ0= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.18/go.mod h1:O1mfO/JzWKUNujOAqD39r7BXqlvhjh/JiPnQ97tvQMc= github.com/aws/aws-sdk-go-v2/service/kms v1.18.1/go.mod h1:4PZMUkc9rXHWGVB5J9vKaZy3D7Nai79ORworQ3ASMiM= -github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.19 h1:6rxMT+zWZh2+0F1XHdDWCSzuMQIJI+tGlfrFi6V/UlU= -github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.19/go.mod h1:wGzRNLBD3V8/KKoBSYz0OWv1dnQNvqTyb193fS97dXQ= +github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.20 h1:jOpM3C6a/W4cd31hj3qok1NZKu3pWYLEg5IwUharV+o= +github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.20/go.mod h1:pvYIQ3quYKA9wXvn5oY6Suu4RqjURwN1tERJssL57nQ= github.com/aws/aws-sdk-go-v2/service/s3 v1.27.2/go.mod h1:u+566cosFI+d+motIz3USXEh6sN8Nq4GrNXSg2RXVMo= -github.com/aws/aws-sdk-go-v2/service/s3 v1.28.0 h1:2TDTNMeOdEBVhuHPS6at9eqAPdco4A1iwRO5tov9Ylg= -github.com/aws/aws-sdk-go-v2/service/s3 v1.28.0/go.mod h1:fmgDANqTUCxciViKl9hb/zD5LFbvPINFRgWhDbR+vZo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.29.0 h1:wmROdhyusq7m7HJgSB9Jm955XU4Kvz0FknIbr1dJTjA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.29.0/go.mod h1:syhASH3D6eA1PCga49mGfvISJh/E2QYaooSIqir3pIM= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.14/go.mod h1:xakbH8KMsQQKqzX87uyyzTHshc/0/Df8bsTneTS5pFU= -github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.2 h1:3x1Qilin49XQ1rK6pDNAfG+DmCFPfB7Rrpl+FUDAR/0= -github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.2/go.mod h1:HEBBc70BYi5eUvxBqC3xXjU/04NO96X/XNUe5qhC7Bc= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.3 h1:d5S+OhXne5O3cIo999RARy/N1dgXW2ldWgD53qbEAP4= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.16.3/go.mod h1:+X/VSQcuvHPWPRlM64HoWUJAPwsD86KpU9Z52lrsodM= github.com/aws/aws-sdk-go-v2/service/sns v1.17.10/go.mod h1:uITsRNVMeCB3MkWpXxXw0eDz8pW4TYLzj+eyQtbhSxM= github.com/aws/aws-sdk-go-v2/service/sqs v1.19.1/go.mod h1:A94o564Gj+Yn+7QO1eLFeI7UVv3riy/YBFOfICVqFvU= github.com/aws/aws-sdk-go-v2/service/ssm v1.27.6/go.mod h1:fiFzQgj4xNOg4/wqmAiPvzgDMXPD+cUEplX/CYn+0j0= github.com/aws/aws-sdk-go-v2/service/sso v1.11.13/go.mod h1:d7ptRksDDgvXaUvxyHZ9SYh+iMDymm94JbVcgvSYSzU= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.23 h1:pwvCchFUEnlceKIgPUouBJwK81aCkQ8UDMORfeFtW10= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.23/go.mod h1:/w0eg9IhFGjGyyncHIQrXtU8wvNsTJOP0R6PPj0wf80= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.6 h1:OwhhKc1P9ElfWbMKPIbMMZBV6hzJlL2JKD76wNNVzgQ= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.6/go.mod h1:csZuQY65DAdFBt1oIjO5hhBR49kQqop4+lcuCjf2arA= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.24 h1:tNfD0JI7VKcIcEzYeIAXCIr8qnoq6DACg3QRt50ofOY= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.24/go.mod h1:7ZC+G3rX2IsGKIhiGDFiul7rgZPApvFy3dDJO7wKtno= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.7 h1:q2FDE8cl8rTPqgrTT0dF7xzIfGAwLMh2P+nU7F2CqVs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.7/go.mod h1:sPh8yf7vmBOI/L9fqP55uq+T9WVoxnqrHMqyvgYC/gA= github.com/aws/aws-sdk-go-v2/service/sts v1.16.10/go.mod h1:cftkHYN6tCDNfkSasAmclSfl4l7cySoay8vz7p/ce0E= -github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 h1:9pPi0PsFNAGILFfPCk8Y0iyEBGc6lu6OQ97U7hmdesg= -github.com/aws/aws-sdk-go-v2/service/sts v1.16.19/go.mod h1:h4J3oPZQbxLhzGnk+j9dfYHi5qIOVJ5kczZd658/ydM= +github.com/aws/aws-sdk-go-v2/service/sts v1.17.0 h1:9S0HcZUxKcU3HdN+M6GgLIvdbg9as5aOoHrvwRsPNYU= +github.com/aws/aws-sdk-go-v2/service/sts v1.17.0/go.mod h1:9pZN58zQc5a4Dkdnhu/rI1lNBui1vP5B0giGCuUt2b0= github.com/aws/smithy-go v1.12.0/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/aws/smithy-go v1.13.3 h1:l7LYxGuzK6/K+NzJ2mC+VvLUbae0sL3bXU//04MkmnA= github.com/aws/smithy-go v1.13.3/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= @@ -2042,8 +2042,9 @@ golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyj golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/cmd/install_windows.go b/internal/cmd/install_windows.go index 02fa174eb..fead83103 100644 --- a/internal/cmd/install_windows.go +++ b/internal/cmd/install_windows.go @@ -109,5 +109,9 @@ func getCustomServeFlags() []string { if logCompress != defaultLogCompress { result = append(result, "--"+logCompressFlag+"=true") } + if graceTime != defaultGraceTime { + result = append(result, "--"+graceTimeFlag) + result = append(result, strconv.Itoa(graceTime)) + } return result } diff --git a/internal/cmd/portable.go b/internal/cmd/portable.go index 01976731f..b0377f801 100644 --- a/internal/cmd/portable.go +++ b/internal/cmd/portable.go @@ -169,6 +169,7 @@ Please take a look at the usage below to customize the serving parameters`, os.Exit(1) } } + service.SetGraceTime(graceTime) service := service.Service{ ConfigDir: filepath.Clean(defaultConfigDir), ConfigFile: defaultConfigFile, @@ -257,8 +258,10 @@ Please take a look at the usage below to customize the serving parameters`, }, }, } - if err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands, portableAdvertiseService, - portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert, portableWebDAVKey); err == nil { + err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands, + portableAdvertiseService, portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert, + portableWebDAVKey) + if err == nil { service.Wait() if service.Error == nil { os.Exit(0) @@ -403,6 +406,13 @@ multiple concurrent requests and this allows data to be transferred at a faster rate, over high latency networks, by overlapping round-trip times`) + portableCmd.Flags().IntVar(&graceTime, graceTimeFlag, 0, + `This grace time defines the number of +seconds allowed for existing transfers +to get completed before shutting down. +A graceful shutdown is triggered by an +interrupt signal. +`) rootCmd.AddCommand(portableCmd) } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index f6463258d..06e8f29ba 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -52,6 +52,8 @@ const ( loadDataQuotaScanKey = "loaddata_scan" loadDataCleanFlag = "loaddata-clean" loadDataCleanKey = "loaddata_clean" + graceTimeFlag = "grace-time" + graceTimeKey = "grace_time" defaultConfigDir = "." defaultConfigFile = "" defaultLogFile = "sftpgo.log" @@ -65,6 +67,7 @@ const ( defaultLoadDataMode = 1 defaultLoadDataQuotaScan = 0 defaultLoadDataClean = false + defaultGraceTime = 0 ) var ( @@ -81,6 +84,7 @@ var ( loadDataMode int loadDataQuotaScan int loadDataClean bool + graceTime int // used if awscontainer build tag is enabled disableAWSInstallationCode bool @@ -262,4 +266,20 @@ This flag can be set using SFTPGO_LOADDATA_QUOTA_SCAN env var too. (default 0)`) viper.BindPFlag(loadDataQuotaScanKey, cmd.Flags().Lookup(loadDataQuotaScanFlag)) //nolint:errcheck + + viper.SetDefault(graceTimeKey, defaultGraceTime) + viper.BindEnv(graceTimeKey, "SFTPGO_GRACE_TIME") //nolint:errcheck + cmd.Flags().IntVar(&graceTime, graceTimeFlag, viper.GetInt(graceTimeKey), + `Graceful shutdown is an option to initiate a +shutdown without abrupt cancellation of the +currently ongoing client-initiated transfer +sessions. +This grace time defines the number of seconds +allowed for existing transfers to get +completed before shutting down. +A graceful shutdown is triggered by an +interrupt signal. +This flag can be set using SFTPGO_GRACE_TIME env +var too. 0 means disabled. (default 0)`) + viper.BindPFlag(graceTimeKey, cmd.Flags().Lookup(graceTimeFlag)) //nolint:errcheck } diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go index 1b7373f69..07189b078 100644 --- a/internal/cmd/serve.go +++ b/internal/cmd/serve.go @@ -34,6 +34,7 @@ $ sftpgo serve Please take a look at the usage below to customize the startup options`, Run: func(_ *cobra.Command, _ []string) { + service.SetGraceTime(graceTime) service := service.Service{ ConfigDir: util.CleanDirInput(configDir), ConfigFile: configFile, diff --git a/internal/cmd/start_windows.go b/internal/cmd/start_windows.go index 7edd391e3..1c1cfba90 100644 --- a/internal/cmd/start_windows.go +++ b/internal/cmd/start_windows.go @@ -34,6 +34,7 @@ var ( if !filepath.IsAbs(logFilePath) && util.IsFileInputValid(logFilePath) { logFilePath = filepath.Join(configDir, logFilePath) } + service.SetGraceTime(graceTime) s := service.Service{ ConfigDir: configDir, ConfigFile: configFile, diff --git a/internal/common/actions.go b/internal/common/actions.go index f8bb5f03a..1729fd092 100644 --- a/internal/common/actions.go +++ b/internal/common/actions.go @@ -26,6 +26,7 @@ import ( "path" "path/filepath" "strings" + "sync/atomic" "time" "github.com/sftpgo/sdk" @@ -44,13 +45,16 @@ var ( errNoHook = errors.New("unable to execute action, no hook defined") errUnexpectedHTTResponse = errors.New("unexpected HTTP hook response code") hooksConcurrencyGuard = make(chan struct{}, 150) + activeHooks atomic.Int32 ) func startNewHook() { + activeHooks.Add(1) hooksConcurrencyGuard <- struct{}{} } func hookEnded() { + activeHooks.Add(-1) <-hooksConcurrencyGuard } diff --git a/internal/common/common.go b/internal/common/common.go index 6661e690e..1ec322898 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -135,6 +135,7 @@ var ( ErrNoCredentials = errors.New("no credential provided") ErrInternalFailure = errors.New("internal failure") ErrTransferAborted = errors.New("transfer aborted") + ErrShuttingDown = errors.New("the service is shutting down") errNoTransfer = errors.New("requested transfer not found") errTransferMismatch = errors.New("transfer mismatch") ) @@ -153,11 +154,13 @@ var ( ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} // the map key is the protocol, for each protocol we can have multiple rate limiters - rateLimiters map[string][]*rateLimiter + rateLimiters map[string][]*rateLimiter + isShuttingDown atomic.Bool ) // Initialize sets the common configuration func Initialize(c Configuration, isShared int) error { + isShuttingDown.Store(false) Config = c Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true) Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true) @@ -220,6 +223,67 @@ func Initialize(c Configuration, isShared int) error { return nil } +// CheckClosing returns an error if the service is closing +func CheckClosing() error { + if isShuttingDown.Load() { + return ErrShuttingDown + } + return nil +} + +// WaitForTransfers waits, for the specified grace time, for currently ongoing +// client-initiated transfer sessions to completes. +// A zero graceTime means no wait +func WaitForTransfers(graceTime int) { + if graceTime == 0 { + return + } + if isShuttingDown.Swap(true) { + return + } + + if activeHooks.Load() == 0 && getActiveConnections() == 0 { + return + } + + graceTimer := time.NewTimer(time.Duration(graceTime) * time.Second) + ticker := time.NewTicker(3 * time.Second) + + for { + select { + case <-ticker.C: + hooks := activeHooks.Load() + logger.Info(logSender, "", "active hooks: %d", hooks) + if hooks == 0 && getActiveConnections() == 0 { + logger.Info(logSender, "", "no more active connections, graceful shutdown") + ticker.Stop() + graceTimer.Stop() + return + } + case <-graceTimer.C: + logger.Info(logSender, "", "grace time expired, hard shutdown") + ticker.Stop() + return + } + } +} + +// getActiveConnections returns the number of connections with active transfers +func getActiveConnections() int { + var activeConns int + + Connections.RLock() + for _, c := range Connections.connections { + if len(c.GetTransfers()) > 0 { + activeConns++ + } + } + Connections.RUnlock() + + logger.Info(logSender, "", "number of connections with active transfers: %d", activeConns) + return activeConns +} + // LimitRate blocks until all the configured rate limiters // allow one event to happen. // It returns an error if the time to wait exceeds the max @@ -1051,30 +1115,34 @@ func (conns *ActiveConnections) GetClientConnections() int32 { return conns.clients.getTotal() } -// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded -// or a whitelist is defined and the specified ipAddr is not listed -func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool { +// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed +// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed +// or the service is shutting down +func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error { + if isShuttingDown.Load() { + return ErrShuttingDown + } if Config.whitelist != nil { if !Config.whitelist.isAllowed(ipAddr) { - return false + return ErrConnectionDenied } } if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { - return true + return nil } if Config.MaxPerHostConnections > 0 { if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections { - logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections) + logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections) AddDefenderEvent(ipAddr, HostEventLimitExceeded) - return false + return ErrConnectionDenied } } if Config.MaxTotalConnections > 0 { if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) { - logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections) - return false + logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections) + return ErrConnectionDenied } // on a single SFTP connection we could have multiple SFTP channels or commands @@ -1083,10 +1151,13 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool { conns.RLock() defer conns.RUnlock() - return len(conns.connections) < Config.MaxTotalConnections + if sess := len(conns.connections); sess >= Config.MaxTotalConnections { + logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections) + return ErrConnectionDenied + } } - return true + return nil } // GetStats returns stats for active connections diff --git a/internal/common/common_test.go b/internal/common/common_test.go index 550a3f18c..2c0d13755 100644 --- a/internal/common/common_test.go +++ b/internal/common/common_test.go @@ -497,10 +497,10 @@ func TestWhitelist(t *testing.T) { err = Initialize(Config, 0) assert.NoError(t, err) - assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.1")) - assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.3")) - assert.True(t, Connections.IsNewConnectionAllowed("10.8.7.3")) - assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.2")) + assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.1")) + assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.3")) + assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.7.3")) + assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.2")) wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3") wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24") @@ -508,14 +508,14 @@ func TestWhitelist(t *testing.T) { assert.NoError(t, err) err = os.WriteFile(wlFile, data, 0664) assert.NoError(t, err) - assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.3")) + assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.3")) err = Reload() assert.NoError(t, err) - assert.True(t, Connections.IsNewConnectionAllowed("10.8.8.3")) - assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.3")) - assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.2")) - assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.12")) + assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.8.3")) + assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.3")) + assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.2")) + assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.12")) Config = configCopy } @@ -550,12 +550,12 @@ func TestMaxConnections(t *testing.T) { Config.MaxPerHostConnections = 0 ipAddr := "192.168.7.8" - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) Config.MaxTotalConnections = 1 Config.MaxPerHostConnections = perHost - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, @@ -563,18 +563,18 @@ func TestMaxConnections(t *testing.T) { err := Connections.Add(fakeConn) assert.NoError(t, err) assert.Len(t, Connections.GetStats(), 1) - assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) res := Connections.Close(fakeConn.GetID()) assert.True(t, res) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond) - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr) - assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) Connections.RemoveClientConnection(ipAddr) - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) Connections.RemoveClientConnection(ipAddr) Config.MaxTotalConnections = oldValue @@ -587,13 +587,13 @@ func TestMaxConnectionPerHost(t *testing.T) { ipAddr := "192.168.9.9" Connections.AddClientConnection(ipAddr) - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) Connections.AddClientConnection(ipAddr) - assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) Connections.AddClientConnection(ipAddr) - assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.Equal(t, int32(3), Connections.GetClientConnections()) Connections.RemoveClientConnection(ipAddr) @@ -697,7 +697,7 @@ func TestCloseConnection(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1")) + assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1")) err := Connections.Add(fakeConn) assert.NoError(t, err) assert.Len(t, Connections.GetStats(), 1) diff --git a/internal/common/connection.go b/internal/common/connection.go index 43b0b7f02..15f0a0c15 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -1362,6 +1362,9 @@ func (c *BaseConnection) GetGenericError(err error) error { if err == vfs.ErrStorageSizeUnavailable { return fmt.Errorf("%w: %v", sftp.ErrSSHFxOpUnsupported, err.Error()) } + if err == ErrShuttingDown { + return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error()) + } if err != nil { if e, ok := err.(*os.PathError); ok { c.Log(logger.LevelError, "generic path error: %+v", e) @@ -1373,7 +1376,7 @@ func (c *BaseConnection) GetGenericError(err error) error { return sftp.ErrSSHFxFailure default: if err == ErrPermissionDenied || err == ErrNotExist || err == ErrOpUnsupported || - err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable { + err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable || err == ErrShuttingDown { return err } return ErrGenericFailure @@ -1406,6 +1409,10 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin return nil, "", err } + if isShuttingDown.Load() { + return nil, "", c.GetFsError(fs, ErrShuttingDown) + } + fsPath, err := fs.ResolvePath(virtualPath) if err != nil { return nil, "", c.GetFsError(fs, err) diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go index b67cd9786..77c3fe876 100644 --- a/internal/common/connection_test.go +++ b/internal/common/connection_test.go @@ -385,6 +385,13 @@ func TestErrorsMapping(t *testing.T) { } else { assert.EqualError(t, err, ErrOpUnsupported.Error()) } + err = conn.GetFsError(fs, ErrShuttingDown) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + assert.Contains(t, err.Error(), ErrShuttingDown.Error()) + } else { + assert.EqualError(t, err, ErrShuttingDown.Error()) + } } } diff --git a/internal/common/eventmanager.go b/internal/common/eventmanager.go index 2304803bf..81f8176b4 100644 --- a/internal/common/eventmanager.go +++ b/internal/common/eventmanager.go @@ -101,10 +101,12 @@ type eventRulesContainer struct { } func (r *eventRulesContainer) addAsyncTask() { + activeHooks.Add(1) r.concurrencyGuard <- struct{}{} } func (r *eventRulesContainer) removeAsyncTask() { + activeHooks.Add(-1) <-r.concurrencyGuard } diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index 67ee6a13c..129aab4a6 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -596,6 +596,92 @@ func TestChtimesOpenHandle(t *testing.T) { assert.NoError(t, err) } +func TestWaitForConnections(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 128 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + testFileSize := int64(524288) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = common.CheckClosing() + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(1 * time.Second) + common.WaitForTransfers(10) + common.WaitForTransfers(0) + common.WaitForTransfers(10) + }() + + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + assert.NoError(t, err) + wg.Wait() + + err = common.CheckClosing() + assert.EqualError(t, err, common.ErrShuttingDown.Error()) + + _, err = client.Stat(testFileName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrShuttingDown.Error()) + } + } + + _, _, err = getSftpClient(user) + assert.Error(t, err) + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + err = client.Remove(testFileName) + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(1 * time.Second) + common.WaitForTransfers(1) + }() + + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + // we don't have an error here because the service won't really stop + assert.NoError(t, err) + wg.Wait() + } + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + common.WaitForTransfers(1) + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestCheckParentDirs(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) @@ -6283,7 +6369,8 @@ func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMetho HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, - Auth: authMethods, + Auth: authMethods, + Timeout: 5 * time.Second, } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if err != nil { @@ -6303,6 +6390,7 @@ func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, + Timeout: 5 * time.Second, } if user.Password != "" { config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} diff --git a/internal/ftpd/server.go b/internal/ftpd/server.go index 85968cee5..0b3198a8a 100644 --- a/internal/ftpd/server.go +++ b/internal/ftpd/server.go @@ -164,9 +164,9 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr) return "Access denied: banned client IP", common.ErrConnectionDenied } - if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolFTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) - return "Access denied", common.ErrConnectionDenied + if err := common.Connections.IsNewConnectionAllowed(ipAddr); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection not allowed from ip %q: %v", ipAddr, err) + return "Access denied", err } _, err := common.LimitRate(common.ProtocolFTP, ipAddr) if err != nil { diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index c1265cd00..11bbef7c8 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -10741,7 +10741,7 @@ func TestWebClientMaxConnections(t *testing.T) { setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -15136,7 +15136,7 @@ func TestWhitelist(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, webLoginPath, nil) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) req.RemoteAddr = "172.120.1.1" rr = executeRequest(req) @@ -15145,7 +15145,7 @@ func TestWhitelist(t *testing.T) { req.RemoteAddr = "172.120.1.3" rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) req.RemoteAddr = "192.8.7.1" rr = executeRequest(req) diff --git a/internal/httpd/server.go b/internal/httpd/server.go index cf76e4e1f..e895e300a 100644 --- a/internal/httpd/server.go +++ b/internal/httpd/server.go @@ -1030,9 +1030,9 @@ func (s *httpdServer) checkConnection(next http.Handler) http.Handler { common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) - if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) - s.sendForbiddenResponse(w, r, "connection not allowed from your ip") + if err := common.Connections.IsNewConnectionAllowed(ipAddr); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection not allowed from ip %q: %v", ipAddr, err) + s.sendForbiddenResponse(w, r, err.Error()) return } if common.IsBanned(ipAddr) { diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go index 6fd1b529a..081ad46ea 100644 --- a/internal/plugin/plugin.go +++ b/internal/plugin/plugin.go @@ -711,8 +711,10 @@ func (m *Manager) removeTask() { // Cleanup releases all the active plugins func (m *Manager) Cleanup() { + if m.closed.Swap(true) { + return + } logger.Debug(logSender, "", "cleanup") - m.closed.Store(true) close(m.done) m.notifLock.Lock() for _, n := range m.notifiers { diff --git a/internal/service/service.go b/internal/service/service.go index 78d374a3a..051e1b578 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -38,7 +38,8 @@ const ( ) var ( - chars = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") + chars = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") + graceTime int ) // Service defines the SFTPGo service @@ -90,8 +91,9 @@ func (s *Service) initLogger() { func (s *Service) Start(disableAWSInstallationCode bool) error { s.initLogger() logger.Info(logSender, "", "starting SFTPGo %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ - "log max age: %v log level: %v, log compress: %v, log utc time: %v, load data from: %#v", version.GetAsString(), s.ConfigDir, s.ConfigFile, - s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, s.LogCompress, s.LogUTCTime, s.LoadDataFrom) + "log max age: %v log level: %v, log compress: %v, log utc time: %v, load data from: %#v, grace time: %d secs", + version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, + s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime) // in portable mode we don't read configuration from file if s.PortableMode != 1 { err := config.LoadConfig(s.ConfigDir, s.ConfigFile) @@ -382,3 +384,8 @@ func (s *Service) restoreDump(dump *dataprovider.BackupData) error { } return nil } + +// SetGraceTime sets the grace time +func SetGraceTime(val int) { + graceTime = val +} diff --git a/internal/service/service_portable.go b/internal/service/service_portable.go index 5a01a0926..4a87d012d 100644 --- a/internal/service/service_portable.go +++ b/internal/service/service_portable.go @@ -29,11 +29,13 @@ import ( "github.com/grandcat/zeroconf" "github.com/sftpgo/sdk" + "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/version" @@ -238,6 +240,8 @@ func (s *Service) advertiseServices(advertiseService, advertiseCredentials bool) logger.InfoToConsole("unregistering multicast DNS WebDAV service") mDNSServiceDAV.Shutdown() } + plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) s.Stop() }() } diff --git a/internal/service/service_windows.go b/internal/service/service_windows.go index 228b6b422..1590bd0cf 100644 --- a/internal/service/service_windows.go +++ b/internal/service/service_windows.go @@ -129,6 +129,7 @@ loop: wasStopped <- true s.Service.Stop() plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) break loop case svc.ParamChange: logger.Debug(logSender, "", "Received reload request") diff --git a/internal/service/signals_unix.go b/internal/service/signals_unix.go index 5c387194f..95c6602e2 100644 --- a/internal/service/signals_unix.go +++ b/internal/service/signals_unix.go @@ -93,5 +93,6 @@ func handleSIGUSR1() { func handleInterrupt() { logger.Debug(logSender, "", "Received interrupt request") plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) os.Exit(0) } diff --git a/internal/service/signals_windows.go b/internal/service/signals_windows.go index 06bfe1d22..2b305f152 100644 --- a/internal/service/signals_windows.go +++ b/internal/service/signals_windows.go @@ -18,6 +18,7 @@ import ( "os" "os/signal" + "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/plugin" ) @@ -29,6 +30,7 @@ func registerSignals() { for range c { logger.Debug(logSender, "", "Received interrupt request") plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) os.Exit(0) } }() diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index 80a54f286..820a7a0d6 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -615,6 +615,15 @@ func TestSSHCommandErrors(t *testing.T) { err = os.Remove(tmpFile) assert.NoError(t, err) } + + common.WaitForTransfers(1) + _, err = cmd.getSystemCommand() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrShuttingDown.Error()) + } + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) } func TestCommandsWithExtensionsFilter(t *testing.T) { diff --git a/internal/sftpd/scp.go b/internal/sftpd/scp.go index 7006a6ab3..8b54dac63 100644 --- a/internal/sftpd/scp.go +++ b/internal/sftpd/scp.go @@ -496,18 +496,10 @@ func (c *scpCommand) handleDownload(filePath string) error { } var err error - fs, err := c.connection.User.GetFilesystemForPath(filePath, c.connection.ID) + fs, p, err := c.connection.GetFsAndResolvedPath(filePath) if err != nil { - c.connection.Log(logger.LevelError, "error downloading file %#v: %+v", filePath, err) - c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", filePath)) - return err - } - - p, err := fs.ResolvePath(filePath) - if err != nil { - err := fmt.Errorf("invalid file path %#v", filePath) - c.connection.Log(logger.LevelError, "error downloading file: %#v, invalid file path", filePath) - c.sendErrorMessage(fs, err) + c.connection.Log(logger.LevelError, "error downloading file %q: %+v", filePath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to download file %q: %w", filePath, err)) return err } diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index 0ec0fe5ba..28b7ad542 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -474,8 +474,8 @@ func canAcceptConnection(ip string) bool { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip) return false } - if !common.Connections.IsNewConnectionAllowed(ip) { - logger.Log(logger.LevelDebug, common.ProtocolSSH, "", fmt.Sprintf("connection not allowed from ip %#v", ip)) + if err := common.Connections.IsNewConnectionAllowed(ip); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err) return false } _, err := common.LimitRate(common.ProtocolSSH, ip) diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 33d26e0f2..a8531ddb6 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -7890,7 +7890,8 @@ func TestOpenUnhandledChannel(t *testing.T) { HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, - Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, + Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, + Timeout: 5 * time.Second, } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if assert.NoError(t, err) { @@ -10667,6 +10668,7 @@ func runSSHCommand(command string, user dataprovider.User, usePubKey bool) ([]by HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, + Timeout: 5 * time.Second, } if usePubKey { key, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) @@ -10715,6 +10717,7 @@ func getSftpClientWithAddr(user dataprovider.User, usePubKey bool, addr string) HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, + Timeout: 5 * time.Second, } if usePubKey { signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) @@ -10760,6 +10763,7 @@ func getKeyboardInteractiveSftpClient(user dataprovider.User, answers []string) return answers, nil }), }, + Timeout: 5 * time.Second, } conn, err := ssh.Dial("tcp", sftpServerAddr, config) if err != nil { @@ -10779,7 +10783,8 @@ func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMetho HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, - Auth: authMethods, + Auth: authMethods, + Timeout: 5 * time.Second, } var err error var conn *ssh.Client diff --git a/internal/sftpd/ssh_cmd.go b/internal/sftpd/ssh_cmd.go index d6b67a470..8a3992cdd 100644 --- a/internal/sftpd/ssh_cmd.go +++ b/internal/sftpd/ssh_cmd.go @@ -482,6 +482,9 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) { fsPath: "", quotaCheckPath: "", } + if err := common.CheckClosing(); err != nil { + return command, err + } args := make([]string, len(c.args)) copy(args, c.args) var fsPath, quotaPath string diff --git a/internal/webdavd/server.go b/internal/webdavd/server.go index 3f0e1ab88..130c60570 100644 --- a/internal/webdavd/server.go +++ b/internal/webdavd/server.go @@ -165,9 +165,9 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { common.Connections.AddClientConnection(ipAddr) defer common.Connections.RemoveClientConnection(ipAddr) - if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) - http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable) + if err := common.Connections.IsNewConnectionAllowed(ipAddr); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection not allowed from ip %q: %v", ipAddr, err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) return } if common.IsBanned(ipAddr) {