Skip to content

Commit e80ca76

Browse files
committed
handle os interrupts properly
1 parent 684fd1f commit e80ca76

8 files changed

Lines changed: 68 additions & 47 deletions

File tree

cmd/s3backup/main.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"log"
88
"os"
9+
"os/signal"
910
"runtime/debug"
1011

1112
"github.com/alecthomas/kong"
@@ -122,8 +123,13 @@ type appCfg struct {
122123
}
123124

124125
func main() {
126+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
127+
defer stop()
128+
125129
description := kong.Description("S3 backup script in a single binary")
126130
app := kong.Parse(&appCfg{}, description, kong.HelpOptions{Compact: true})
131+
app.BindTo(ctx, (*context.Context)(nil))
132+
127133
if err := app.Run(); err != nil {
128134
log.Fatalln(err)
129135
}
@@ -146,7 +152,7 @@ func (c *versionCommand) Run() error {
146152
return nil
147153
}
148154

149-
func (c *putCommand) Run() error {
155+
func (c *putCommand) Run(ctx context.Context) error {
150156
app, err := newClient(c.awsFlags, c.SymKey, c.PemKey, c.SkipHash)
151157
if err != nil {
152158
return err
@@ -155,10 +161,10 @@ func (c *putCommand) Run() error {
155161
if err != nil {
156162
return err
157163
}
158-
return app.PutLocalFile(remote, local)
164+
return app.PutLocalFile(ctx, remote, local)
159165
}
160166

161-
func (c *getCommand) Run() error {
167+
func (c *getCommand) Run(ctx context.Context) error {
162168
app, err := newClient(c.awsFlags, c.SymKey, c.PemKey, c.SkipHash)
163169
if err != nil {
164170
return err
@@ -167,11 +173,11 @@ func (c *getCommand) Run() error {
167173
if err != nil {
168174
return err
169175
}
170-
return app.GetRemoteFile(remote, local)
176+
return app.GetRemoteFile(ctx, remote, local)
171177
}
172178

173-
func (c *vaultPutCommand) Run() error {
174-
cfg, err := vaultConfig(c.vaultFlags)
179+
func (c *vaultPutCommand) Run(ctx context.Context) error {
180+
cfg, err := vaultConfig(ctx, c.vaultFlags)
175181
if err != nil {
176182
return err
177183
}
@@ -189,11 +195,11 @@ func (c *vaultPutCommand) Run() error {
189195
awsFlags: awsConfig(cfg),
190196
encryptFlags: encryptFlags{SymKey: cfg.CipherKey, PemKey: pubKeyFile},
191197
}
192-
return cmd.Run()
198+
return cmd.Run(ctx)
193199
}
194200

195-
func (c vaultGetCommand) Run() error {
196-
cfg, err := vaultConfig(c.vaultFlags)
201+
func (c vaultGetCommand) Run(ctx context.Context) error {
202+
cfg, err := vaultConfig(ctx, c.vaultFlags)
197203
if err != nil {
198204
return err
199205
}
@@ -211,7 +217,7 @@ func (c vaultGetCommand) Run() error {
211217
awsFlags: awsConfig(cfg),
212218
decryptFlags: decryptFlags{SymKey: cfg.CipherKey, PemKey: privKeyFile},
213219
}
214-
return cmd.Run()
220+
return cmd.Run(ctx)
215221
}
216222

217223
func (c *genAesCommand) Run() error {
@@ -303,8 +309,8 @@ func checkPaths(inRemote, inLocal string) (outRemote string, outLocal string, er
303309
return
304310
}
305311

306-
func vaultConfig(f vaultFlags) (*config.Config, error) {
307-
return config.Lookup(context.Background(), config.VaultOpts{
312+
func vaultConfig(ctx context.Context, f vaultFlags) (*config.Config, error) {
313+
return config.Lookup(ctx, config.VaultOpts{
308314
Path: f.Path,
309315
Token: f.Token,
310316
RoleID: f.RoleID,

internal/client/client.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package client
22

33
import (
4+
"context"
45
"log"
56
"os"
67
)
@@ -13,15 +14,15 @@ type Client struct {
1314
Store Store
1415
}
1516

16-
func (c *Client) GetRemoteFile(remotePath, localPath string) error {
17+
func (c *Client) GetRemoteFile(ctx context.Context, remotePath, localPath string) error {
1718
tempFile := localPath
1819
if c.Cipher != nil {
1920
tempFile += tempFileSuffix
2021
defer remove(tempFile)
2122
}
2223

2324
log.Println("Downloading", remotePath, "to", tempFile)
24-
checksum, cerr := c.Store.DownloadFile(remotePath, tempFile)
25+
checksum, cerr := c.Store.DownloadFile(ctx, remotePath, tempFile)
2526
if cerr != nil {
2627
return cerr
2728
}
@@ -43,7 +44,7 @@ func (c *Client) GetRemoteFile(remotePath, localPath string) error {
4344
return nil
4445
}
4546

46-
func (c *Client) PutLocalFile(remotePath, localPath string) error {
47+
func (c *Client) PutLocalFile(ctx context.Context, remotePath, localPath string) error {
4748
tempFile := localPath
4849
if c.Cipher != nil {
4950
tempFile += tempFileSuffix
@@ -66,7 +67,7 @@ func (c *Client) PutLocalFile(remotePath, localPath string) error {
6667
}
6768

6869
log.Println("Uploading", tempFile, "as", remotePath)
69-
return c.Store.UploadFile(remotePath, tempFile, checksum)
70+
return c.Store.UploadFile(ctx, remotePath, tempFile, checksum)
7071
}
7172

7273
func remove(filePath string) {

internal/client/client_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package client
22

33
import (
4+
"context"
45
"testing"
56

67
"gotest.tools/v3/assert"
@@ -16,7 +17,7 @@ func TestGetRemoteFileWithoutDecryption(t *testing.T) {
1617
},
1718
}
1819
store := &StoreStub{
19-
DownloadFileFunc: func(remotePath string, localPath string) (string, error) {
20+
DownloadFileFunc: func(ctx context.Context, remotePath string, localPath string) (string, error) {
2021
assert.Check(t, is.Equal("s3://foo/bar.txt", remotePath))
2122
assert.Check(t, is.Equal("bar.txt", localPath))
2223
return "muahahaha", nil
@@ -26,7 +27,7 @@ func TestGetRemoteFileWithoutDecryption(t *testing.T) {
2627
Hash: hash,
2728
Store: store,
2829
}
29-
assert.NilError(t, c.GetRemoteFile("s3://foo/bar.txt", "bar.txt"))
30+
assert.NilError(t, c.GetRemoteFile(t.Context(), "s3://foo/bar.txt", "bar.txt"))
3031
}
3132

3233
func TestGetRemoteFileWithDecryption(t *testing.T) {
@@ -38,7 +39,7 @@ func TestGetRemoteFileWithDecryption(t *testing.T) {
3839
},
3940
}
4041
store := &StoreStub{
41-
DownloadFileFunc: func(remotePath string, localPath string) (string, error) {
42+
DownloadFileFunc: func(ctx context.Context, remotePath string, localPath string) (string, error) {
4243
assert.Check(t, is.Equal("s3://foo/bar.txt", remotePath))
4344
assert.Check(t, is.Equal("bar.txt.tmp", localPath))
4445
return "muahahaha", nil
@@ -56,7 +57,7 @@ func TestGetRemoteFileWithDecryption(t *testing.T) {
5657
Store: store,
5758
Cipher: cipher,
5859
}
59-
assert.NilError(t, c.GetRemoteFile("s3://foo/bar.txt", "bar.txt"))
60+
assert.NilError(t, c.GetRemoteFile(t.Context(), "s3://foo/bar.txt", "bar.txt"))
6061
}
6162

6263
func TestPutLocalFileWithoutEncryption(t *testing.T) {
@@ -67,7 +68,7 @@ func TestPutLocalFileWithoutEncryption(t *testing.T) {
6768
},
6869
}
6970
store := &StoreStub{
70-
UploadFileFunc: func(remotePath string, localPath string, checksum string) error {
71+
UploadFileFunc: func(ctx context.Context, remotePath string, localPath string, checksum string) error {
7172
assert.Check(t, is.Equal("s3://foo/bar.txt", remotePath))
7273
assert.Check(t, is.Equal("bar.txt", localPath))
7374
assert.Check(t, is.Equal("woahahaha", checksum))
@@ -78,7 +79,7 @@ func TestPutLocalFileWithoutEncryption(t *testing.T) {
7879
Hash: hash,
7980
Store: store,
8081
}
81-
assert.NilError(t, c.PutLocalFile("s3://foo/bar.txt", "bar.txt"))
82+
assert.NilError(t, c.PutLocalFile(t.Context(), "s3://foo/bar.txt", "bar.txt"))
8283
}
8384

8485
func TestPutLocalFileWithEncryption(t *testing.T) {
@@ -89,7 +90,7 @@ func TestPutLocalFileWithEncryption(t *testing.T) {
8990
},
9091
}
9192
store := &StoreStub{
92-
UploadFileFunc: func(remotePath string, localPath string, checksum string) error {
93+
UploadFileFunc: func(ctx context.Context, remotePath string, localPath string, checksum string) error {
9394
assert.Check(t, is.Equal("s3://foo/bar.txt", remotePath))
9495
assert.Check(t, is.Equal("bar.txt.tmp", localPath))
9596
assert.Check(t, is.Equal("woahahaha", checksum))
@@ -108,5 +109,5 @@ func TestPutLocalFileWithEncryption(t *testing.T) {
108109
Store: store,
109110
Cipher: cipher,
110111
}
111-
assert.NilError(t, c.PutLocalFile("s3://foo/bar.txt", "bar.txt"))
112+
assert.NilError(t, c.PutLocalFile(t.Context(), "s3://foo/bar.txt", "bar.txt"))
112113
}

internal/client/store.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package client
22

3+
import "context"
4+
35
//go:generate go run github.com/matryer/moq -out store_stub.go . Store:StoreStub
46

57
type Store interface {
6-
UploadFile(remotePath, localPath, checksum string) error
7-
DownloadFile(remotePath, localPath string) (checksum string, err error)
8+
UploadFile(ctx context.Context, remotePath, localPath, checksum string) error
9+
DownloadFile(ctx context.Context, remotePath, localPath string) (checksum string, err error)
810
}

internal/client/store/s3.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package store
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67

@@ -57,7 +58,7 @@ func NewS3(opts AwsOpts) (client.Store, error) {
5758
return &s3store{s3.New(awsSession)}, nil
5859
}
5960

60-
func (s *s3store) UploadFile(remotePath, localPath, checksum string) error {
61+
func (s *s3store) UploadFile(ctx context.Context, remotePath, localPath, checksum string) error {
6162
bucket, objectKey, err := splitRemotePath(remotePath)
6263
if err != nil {
6364
return err
@@ -80,14 +81,14 @@ func (s *s3store) UploadFile(remotePath, localPath, checksum string) error {
8081
checksumKey: aws.String(checksum),
8182
}
8283
}
83-
_, err = uploader.Upload(input)
84+
_, err = uploader.UploadWithContext(ctx, input)
8485
if err != nil {
8586
return fmt.Errorf("failed to upload file: %w", err)
8687
}
8788
return nil
8889
}
8990

90-
func (s *s3store) DownloadFile(remotePath, localPath string) (string, error) {
91+
func (s *s3store) DownloadFile(ctx context.Context, remotePath, localPath string) (string, error) {
9192
bucket, objectKey, err := splitRemotePath(remotePath)
9293
if err != nil {
9394
return "", err
@@ -103,7 +104,7 @@ func (s *s3store) DownloadFile(remotePath, localPath string) (string, error) {
103104
downloader := s3manager.NewDownloaderWithClient(s.api)
104105
req := &s3.GetObjectInput{Bucket: aws.String(bucket), Key: aws.String(objectKey)}
105106
opt := request.WithGetResponseHeader(fmt.Sprintf("x-amz-meta-%s", checksumKey), &checksum)
106-
_, err = downloader.Download(file, req, s3manager.WithDownloaderRequestOptions(opt))
107+
_, err = downloader.DownloadWithContext(ctx, file, req, s3manager.WithDownloaderRequestOptions(opt))
107108
if err != nil {
108109
return "", fmt.Errorf("download failed: %w", err)
109110
}

internal/client/store/s3_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ func TestRoundTripUploadDownload_withChecksum(t *testing.T) {
4141
_, err = impl.api.CreateBucket(&s3.CreateBucketInput{Bucket: aws.String("test-bucket")})
4242
assert.NilError(t, err, "failed to create bucket")
4343

44-
err = target.UploadFile("s3://test-bucket/test-file", uploadFile, "wibble")
44+
err = target.UploadFile(t.Context(), "s3://test-bucket/test-file", uploadFile, "wibble")
4545
assert.NilError(t, err, "failed to upload file")
4646

4747
downloadFile := uploadFile + ".download"
48-
checksum, err := target.DownloadFile("s3://test-bucket/test-file", downloadFile)
48+
checksum, err := target.DownloadFile(t.Context(), "s3://test-bucket/test-file", downloadFile)
4949
assert.NilError(t, err, "failed to download file")
5050
defer os.Remove(downloadFile)
5151

@@ -83,11 +83,11 @@ func TestRoundTripUploadDownload_withoutChecksum(t *testing.T) {
8383
_, err = impl.api.CreateBucket(&s3.CreateBucketInput{Bucket: aws.String("test-bucket")})
8484
assert.NilError(t, err, "failed to create bucket")
8585

86-
err = target.UploadFile("s3://test-bucket/test-file", uploadFile, "")
86+
err = target.UploadFile(t.Context(), "s3://test-bucket/test-file", uploadFile, "")
8787
assert.NilError(t, err, "failed to upload file")
8888

8989
downloadFile := uploadFile + ".download"
90-
checksum, err := target.DownloadFile("s3://test-bucket/test-file", downloadFile)
90+
checksum, err := target.DownloadFile(t.Context(), "s3://test-bucket/test-file", downloadFile)
9191
assert.NilError(t, err, "failed to download file")
9292
defer os.Remove(downloadFile)
9393

0 commit comments

Comments
 (0)