Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sftp): make sure to delete last file when watch and delete_on_finish are enabled #3037

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 115 additions & 91 deletions internal/impl/sftp/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,61 @@ func newSFTPReaderFromParsed(conf *service.ParsedConfig, mgr *service.Resources)
}

func (s *sftpReader) Connect(ctx context.Context) (err error) {
file, nextPath, skip, err := s.seekNextPath(ctx)
if err != nil {
return err
}
if skip {
return nil
}

details := service.NewScannerSourceDetails()
details.SetName(nextPath)
if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
_ = s.pathProvider.Ack(ctx, nextPath, aErr)
if aErr != nil {
return nil
}
if s.deleteOnFinish {
s.scannerMut.Lock()
client := s.client
if client == nil {
if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
outErr = fmt.Errorf("obtain private client: %w", outErr)
}
defer func() {
_ = client.Close()
}()
}
if outErr == nil {
if outErr = client.Remove(nextPath); outErr != nil {
outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
}
}
s.scannerMut.Unlock()
}
return
}, details); err != nil {
_ = file.Close()
_ = s.pathProvider.Ack(ctx, nextPath, err)
return err
}

s.scannerMut.Lock()
s.currentPath = nextPath
s.scannerMut.Unlock()

s.log.Debugf("Consuming from file '%v'", nextPath)
return
}

func (s *sftpReader) initState(ctx context.Context) (client *sftp.Client, pathProvider pathProvider, skip bool, err error) {
s.scannerMut.Lock()
defer s.scannerMut.Unlock()

if s.scanner != nil {
return nil
skip = true
return
}

if s.client == nil {
Expand All @@ -191,13 +241,22 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
s.pathProvider = s.getFilePathProvider(ctx)
}

var nextPath string
var file *sftp.File
return s.client, s.pathProvider, false, nil
}

func (s *sftpReader) seekNextPath(ctx context.Context) (file *sftp.File, nextPath string, skip bool, err error) {
client, pathProvider, skip, err := s.initState(ctx)
if err != nil || skip {
return
}

for {
if nextPath, err = s.pathProvider.Next(ctx, s.client); err != nil {
if nextPath, err = pathProvider.Next(ctx, client); err != nil {
if errors.Is(err, sftp.ErrSshFxConnectionLost) {
_ = s.client.Close()
_ = client.Close()
s.scannerMut.Lock()
s.client = nil
s.scannerMut.Unlock()
return
}
if errors.Is(err, errEndOfPaths) {
Expand All @@ -206,62 +265,28 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
return
}

if file, err = s.client.Open(nextPath); err != nil {
if file, err = client.Open(nextPath); err != nil {
if errors.Is(err, sftp.ErrSshFxConnectionLost) {
_ = s.client.Close()
_ = client.Close()
s.scannerMut.Lock()
s.client = nil
s.scannerMut.Unlock()
}

s.log.With("path", nextPath, "err", err.Error()).Warn("Unable to open previously identified file")
if os.IsNotExist(err) {
// If we failed to open the file because it no longer exists
// then we can "ack" the path as we're done with it.
_ = s.pathProvider.Ack(ctx, nextPath, nil)
_ = pathProvider.Ack(ctx, nextPath, nil)
} else {
// Otherwise we "nack" it with the error as we'll want to
// reprocess it again later.
_ = s.pathProvider.Ack(ctx, nextPath, err)
_ = pathProvider.Ack(ctx, nextPath, err)
}
} else {
break
}
}

details := service.NewScannerSourceDetails()
details.SetName(nextPath)
if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
_ = s.pathProvider.Ack(ctx, nextPath, aErr)
if aErr != nil {
return nil
}
if s.deleteOnFinish {
s.scannerMut.Lock()
client := s.client
if client == nil {
if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
outErr = fmt.Errorf("obtain private client: %w", outErr)
}
defer func() {
_ = client.Close()
}()
}
if outErr == nil {
if outErr = client.Remove(nextPath); outErr != nil {
outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
}
}
s.scannerMut.Unlock()
return
}
return
}, details); err != nil {
_ = file.Close()
_ = s.pathProvider.Ack(ctx, nextPath, err)
return err
}
s.currentPath = nextPath

s.log.Debugf("Consuming from file '%v'", nextPath)
return
}

func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
Expand Down Expand Up @@ -297,9 +322,7 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi
part.MetaSetMut("sftp_path", currentPath)
}

return parts, func(ctx context.Context, res error) error {
return codecAckFn(ctx, res)
}, nil
return parts, codecAckFn, nil
}

func (s *sftpReader) Close(ctx context.Context) error {
Expand Down Expand Up @@ -363,61 +386,62 @@ type watcherPathProvider struct {
}

func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (string, error) {
if len(w.expandedPaths) > 0 {
nextPath := w.expandedPaths[0]
w.expandedPaths = w.expandedPaths[1:]
return nextPath, nil
}

if waitFor := time.Until(w.nextPoll); waitFor > 0 {
w.nextPoll = time.Now().Add(w.pollInterval)
select {
case <-time.After(waitFor):
case <-ctx.Done():
return "", ctx.Err()
for {
if len(w.expandedPaths) > 0 {
nextPath := w.expandedPaths[0]
w.expandedPaths = w.expandedPaths[1:]
return nextPath, nil
}
}

if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err != nil {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
continue
if waitFor := time.Until(w.nextPoll); w.nextPoll.IsZero() || waitFor > 0 {
w.nextPoll = time.Now().Add(w.pollInterval)
select {
case <-time.After(waitFor):
case <-ctx.Done():
return "", ctx.Err()
}
}

for _, path := range paths {
info, err := client.Stat(path)
if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err != nil {
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
continue
}

// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
for _, path := range paths {
info, err := client.Stat(path)
if err != nil {
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
continue
}

// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
}
}
}
}
}); cerr != nil {
return "", fmt.Errorf("error obtaining cache: %v", cerr)
}
}); cerr != nil {
return "", fmt.Errorf("error obtaining cache: %v", cerr)
w.followUpPoll = true
}
w.followUpPoll = true
return w.Next(ctx, client)
}

func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) (outErr error) {
Expand Down
Loading
Loading