Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.

package errors

import "errors"

var ErrInsufficientDiskSpace = errors.New("insufficient disk space")

func IsDiskSpaceError(err error) bool {
for _, osErr := range OS_DiskSpaceErrors {
if errors.Is(err, osErr) {
return true
}
}

return false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.

//go:build !windows

package errors

import "syscall"

var OS_DiskSpaceErrors = []error{
syscall.ENOSPC,
syscall.EDQUOT,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.

package errors

import (
goerrors "errors"
"fmt"
"testing"

"github.com/stretchr/testify/require"

agentErrors "github.com/elastic/elastic-agent/internal/pkg/agent/errors"
)

func TestIsDiskSpaceError(t *testing.T) {
for _, err := range OS_DiskSpaceErrors {
testCases := map[string]struct {
err error
want bool
}{
"os_error": {err: err, want: true},
"wrapped_os_error": {err: fmt.Errorf("wrapped: %w", err), want: true},
"joined_error": {err: goerrors.Join(err, goerrors.New("test")), want: true},
"new_error": {err: agentErrors.New(err, fmt.Errorf("test")), want: false},
}
for name, tc := range testCases {
t.Run(fmt.Sprintf("%s_%s", err.Error(), name), func(t *testing.T) {
require.Equal(t, tc.want, IsDiskSpaceError(tc.err))
})
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.

//go:build windows

package errors

import "golang.org/x/sys/windows"

var OS_DiskSpaceErrors = []error{
windows.ERROR_DISK_FULL,
windows.ERROR_HANDLE_DISK_FULL,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package fs

import (
"context"
goerrors "errors"
"fmt"
"io"
"os"
Expand All @@ -27,13 +28,20 @@ const (
type Downloader struct {
dropPath string
config *artifact.Config
// The following are abstractions for stdlib functions so that we can mock them in tests.
copy func(dst io.Writer, src io.Reader) (int64, error)
mkdirAll func(name string, perm os.FileMode) error
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
}

// NewDownloader creates and configures Elastic Downloader
func NewDownloader(config *artifact.Config) *Downloader {
return &Downloader{
config: config,
dropPath: getDropPath(config),
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
}
}

Expand Down Expand Up @@ -108,18 +116,18 @@ func (e *Downloader) downloadFile(filename, fullPath string) (string, error) {
defer sourceFile.Close()

if destinationDir := filepath.Dir(fullPath); destinationDir != "" && destinationDir != "." {
if err := os.MkdirAll(destinationDir, 0755); err != nil {
if err := e.mkdirAll(destinationDir, 0755); err != nil {
return "", err
}
}

destinationFile, err := os.OpenFile(fullPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, packagePermissions)
destinationFile, err := e.openFile(fullPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, packagePermissions)
if err != nil {
return "", errors.New(err, "creating package file failed", errors.TypeFilesystem, errors.M(errors.MetaKeyPath, fullPath))
return "", goerrors.Join(errors.New("creating package file failed", errors.TypeFilesystem, errors.M(errors.MetaKeyPath, fullPath)), err)
}
defer destinationFile.Close()

_, err = io.Copy(destinationFile, sourceFile)
_, err = e.copy(destinationFile, sourceFile)
if err != nil {
return "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ package fs
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/paths"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
agtversion "github.com/elastic/elastic-agent/pkg/version"
)

Expand Down Expand Up @@ -161,6 +164,9 @@ func TestDownloader_Download(t *testing.T) {
e := &Downloader{
dropPath: dropPath,
config: config,
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
}
got, err := e.Download(context.TODO(), tt.args.a, tt.args.version)
if !tt.wantErr(t, err, fmt.Sprintf("Download(%v, %v)", tt.args.a, tt.args.version)) {
Expand Down Expand Up @@ -282,6 +288,9 @@ func TestDownloader_DownloadAsc(t *testing.T) {
e := &Downloader{
dropPath: dropPath,
config: config,
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
}
got, err := e.DownloadAsc(context.TODO(), tt.args.a, tt.args.version)
if !tt.wantErr(t, err, fmt.Sprintf("DownloadAsc(%v, %v)", tt.args.a, tt.args.version)) {
Expand All @@ -291,3 +300,76 @@ func TestDownloader_DownloadAsc(t *testing.T) {
})
}
}

func TestDownloadDiskSpaceError(t *testing.T) {
testError := errors.New("test error")

testCases := map[string]struct {
mockStdlibFuncs func(downloader *Downloader)
expectedError error
}{
"when io.Copy runs into an error, the downloader should return the error and clean up the downloaded files": {
mockStdlibFuncs: func(downloader *Downloader) {
downloader.copy = func(dst io.Writer, src io.Reader) (int64, error) {
return 0, testError
}
},
expectedError: testError,
},
"when os.OpenFile runs into an error, the downloader should return the error and clean up the downloaded files": {
mockStdlibFuncs: func(downloader *Downloader) {
downloader.openFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
return nil, testError
}
},
expectedError: testError,
},
"when os.MkdirAll runs into an error, the downloader should return the error and clean up the downloaded files": {
mockStdlibFuncs: func(downloader *Downloader) {
downloader.mkdirAll = func(name string, perm os.FileMode) error {
return testError
}
},
expectedError: testError,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
baseDir := t.TempDir()
paths.SetTop(baseDir)
config := &artifact.Config{
DropPath: filepath.Join(baseDir, "drop"),
TargetDirectory: filepath.Join(baseDir, "target"),
}

err := os.MkdirAll(config.DropPath, 0o755)
require.NoError(t, err)

err = os.MkdirAll(config.TargetDirectory, 0o755)
require.NoError(t, err)

parsedVersion := agtversion.NewParsedSemVer(1, 2, 3, "", "")

artifactName, err := artifact.GetArtifactName(agentSpec, *parsedVersion, config.OS(), config.Arch())
require.NoError(t, err)

sourceArtifactPath := filepath.Join(config.DropPath, artifactName)
sourceArtifactHashPath := sourceArtifactPath + ".sha512"

err = os.WriteFile(sourceArtifactPath, []byte("test"), 0o666)
require.NoError(t, err, "failed to create source artifact file")

err = os.WriteFile(sourceArtifactHashPath, []byte("test"), 0o666)
require.NoError(t, err, "failed to create source artifact hash file")

downloader := NewDownloader(config)
tc.mockStdlibFuncs(downloader)
targetArtifactPath, err := downloader.Download(context.Background(), agentSpec, parsedVersion)

require.ErrorIs(t, err, tc.expectedError)

require.NoFileExists(t, targetArtifactPath)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package http

import (
"context"
goerrors "errors"
"fmt"
"io"
"net/http"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/elastic/elastic-agent-libs/transport/httpcommon"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download"
downloadErrors "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/errors"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details"
"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
"github.com/elastic/elastic-agent/pkg/core/logger"
Expand Down Expand Up @@ -49,6 +51,12 @@ type Downloader struct {
config *artifact.Config
client http.Client
upgradeDetails *details.Details
// The following are abstractions for stdlib functions so that we can mock them in tests.
copy func(dst io.Writer, src io.Reader) (int64, error)
mkdirAll func(name string, perm os.FileMode) error
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
// Abstraction for the disk space error check function so that we can mock it in tests.
isDiskSpaceErrorFunc func(err error) bool
}

// NewDownloader creates and configures Elastic Downloader
Expand All @@ -68,10 +76,14 @@ func NewDownloader(log *logger.Logger, config *artifact.Config, upgradeDetails *
// NewDownloaderWithClient creates Elastic Downloader with specific client used
func NewDownloaderWithClient(log *logger.Logger, config *artifact.Config, client http.Client, upgradeDetails *details.Details) *Downloader {
return &Downloader{
log: log,
config: config,
client: client,
upgradeDetails: upgradeDetails,
log: log,
config: config,
client: client,
upgradeDetails: upgradeDetails,
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
isDiskSpaceErrorFunc: downloadErrors.IsDiskSpaceError,
}
}

Expand Down Expand Up @@ -179,14 +191,14 @@ func (e *Downloader) downloadFile(ctx context.Context, artifactName, filename, f
}

if destinationDir := filepath.Dir(fullPath); destinationDir != "" && destinationDir != "." {
if err := os.MkdirAll(destinationDir, 0o755); err != nil {
if err := e.mkdirAll(destinationDir, 0o755); err != nil {
return "", err
}
}

destinationFile, err := os.OpenFile(fullPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, packagePermissions)
destinationFile, err := e.openFile(fullPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, packagePermissions)
if err != nil {
return "", errors.New(err, "creating package file failed", errors.TypeFilesystem, errors.M(errors.MetaKeyPath, fullPath))
return "", goerrors.Join(errors.New("creating package file failed", errors.TypeFilesystem, errors.M(errors.MetaKeyPath, fullPath)), err)
}
defer destinationFile.Close()

Expand All @@ -213,11 +225,18 @@ func (e *Downloader) downloadFile(ctx context.Context, artifactName, filename, f
detailsObserver := newDetailsProgressObserver(e.upgradeDetails)
dp := newDownloadProgressReporter(sourceURI, e.config.Timeout, fileSize, loggingObserver, detailsObserver)
dp.Report(ctx)
_, err = io.Copy(destinationFile, io.TeeReader(resp.Body, dp))

_, err = e.copy(destinationFile, io.TeeReader(resp.Body, dp))
if err != nil {
dp.ReportFailed(err)
// checking for disk space error here before passing it into the reporter
// so the details observer sets the state with clean error message
reportedErr := err
if e.isDiskSpaceErrorFunc(err) {
reportedErr = downloadErrors.ErrInsufficientDiskSpace
}
dp.ReportFailed(reportedErr)
// return path, file already exists and needs to be cleaned up
return fullPath, errors.New(err, "copying fetched package failed", errors.TypeNetwork, errors.M(errors.MetaKeyURI, sourceURI))
return fullPath, goerrors.Join(errors.New("copying fetched package failed", errors.TypeNetwork, errors.M(errors.MetaKeyURI, sourceURI)), err)
}
dp.ReportComplete()

Expand Down
Loading
Loading