diff --git a/changelog/fragments/1755268130-cleanup-downloads-directory-and-the-new-versioned-home-if-upgrade-fails.yaml b/changelog/fragments/1755268130-cleanup-downloads-directory-and-the-new-versioned-home-if-upgrade-fails.yaml new file mode 100644 index 00000000000..7a43b61d140 --- /dev/null +++ b/changelog/fragments/1755268130-cleanup-downloads-directory-and-the-new-versioned-home-if-upgrade-fails.yaml @@ -0,0 +1,32 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: enhancement + +# Change summary; a 80ish characters long description of the change. +summary: agent cleans up downloads directory and the new versioned home if upgrade fails + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +#description: + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: "elastic-agent" + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. 
+pr: https://github.com/elastic/elastic-agent/pull/9386 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/elastic-agent/issues/5235 diff --git a/internal/pkg/agent/application/upgrade/artifact/download/verifier.go b/internal/pkg/agent/application/upgrade/artifact/download/verifier.go index 67d16076f4e..2faf9e762a8 100644 --- a/internal/pkg/agent/application/upgrade/artifact/download/verifier.go +++ b/internal/pkg/agent/application/upgrade/artifact/download/verifier.go @@ -106,12 +106,12 @@ func VerifySHA512HashWithCleanup(log infoWarnLogger, filename string) error { } } else if err != nil && !errors.Is(err, os.ErrNotExist) { // it's not a simple hash mismatch, probably something is wrong with the hash file - hashFileName := getHashFileName(filename) + hashFileName := AddHashExtension(filename) hashFileBytes, readErr := os.ReadFile(hashFileName) if readErr != nil { - log.Warnf("error verifying the package using hash file %q, unable do read contents for logging: %v", getHashFileName(filename), readErr) + log.Warnf("error verifying the package using hash file %q, unable do read contents for logging: %v", AddHashExtension(filename), readErr) } else { - log.Warnf("error verifying the package using hash file %q, contents: %q", getHashFileName(filename), string(hashFileBytes)) + log.Warnf("error verifying the package using hash file %q, contents: %q", AddHashExtension(filename), string(hashFileBytes)) } } @@ -121,12 +121,12 @@ func VerifySHA512HashWithCleanup(log infoWarnLogger, filename string) error { return nil } -func getHashFileName(filename string) string { +func AddHashExtension(file string) string { const hashFileExt = ".sha512" - if strings.HasSuffix(filename, hashFileExt) { - return filename + if strings.HasSuffix(file, hashFileExt) { + return file } - return filename + 
hashFileExt + return file + hashFileExt } // VerifySHA512Hash checks that a sidecar file containing a sha512 checksum @@ -134,7 +134,7 @@ func getHashFileName(filename string) string { // the file. It returns an error if validation fails. func VerifySHA512Hash(filename string) error { hasher := sha512.New() - checksumFileName := getHashFileName(filename) + checksumFileName := AddHashExtension(filename) return VerifyChecksum(hasher, filename, checksumFileName) } diff --git a/internal/pkg/agent/application/upgrade/step_download.go b/internal/pkg/agent/application/upgrade/step_download.go index 8b8966b3a20..f94aa8fdedd 100644 --- a/internal/pkg/agent/application/upgrade/step_download.go +++ b/internal/pkg/agent/application/upgrade/step_download.go @@ -41,16 +41,21 @@ type downloaderFactory func(*agtversion.ParsedSemVer, *logger.Logger, *artifact. type downloader func(context.Context, downloaderFactory, *agtversion.ParsedSemVer, *artifact.Config, *details.Details) (string, error) +// abstraction for testability for newVerifier +type verifierFactory func(*agtversion.ParsedSemVer, *logger.Logger, *artifact.Config) (download.Verifier, error) + type artifactDownloader struct { log *logger.Logger settings *artifact.Config fleetServerURI string + newVerifier verifierFactory } func newArtifactDownloader(settings *artifact.Config, log *logger.Logger) *artifactDownloader { return &artifactDownloader{ - log: log, - settings: settings, + log: log, + settings: settings, + newVerifier: newVerifier, } } @@ -123,19 +128,21 @@ func (a *artifactDownloader) downloadArtifact(ctx context.Context, parsedVersion return "", fmt.Errorf("failed download of agent binary: %w", err) } + // If there are errors in the following steps, we return the path so that we + // can cleanup the downloaded files. 
if skipVerifyOverride { return path, nil } if verifier == nil { - verifier, err = newVerifier(parsedVersion, a.log, &settings) + verifier, err = a.newVerifier(parsedVersion, a.log, &settings) if err != nil { - return "", errors.New(err, "initiating verifier") + return path, errors.New(err, "initiating verifier") } } if err := verifier.Verify(ctx, agentArtifact, *parsedVersion, skipDefaultPgp, pgpBytes...); err != nil { - return "", errors.New(err, "failed verification of agent binary") + return path, errors.New(err, "failed verification of agent binary") } return path, nil } diff --git a/internal/pkg/agent/application/upgrade/step_download_test.go b/internal/pkg/agent/application/upgrade/step_download_test.go index 1cd8cef6f56..33596b3de98 100644 --- a/internal/pkg/agent/application/upgrade/step_download_test.go +++ b/internal/pkg/agent/application/upgrade/step_download_test.go @@ -8,6 +8,8 @@ import ( "context" "encoding/json" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" "time" @@ -15,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent-libs/transport/httpcommon" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download" downloadErrors "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/errors" @@ -303,6 +306,86 @@ func TestDownloadWithRetries(t *testing.T) { }) } +type mockVerifier struct { + called bool + returnError error +} + +func (mv *mockVerifier) Name() string { + return "" +} + +func (mv *mockVerifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error { + mv.called = true + return mv.returnError +} + +func TestDownloadArtifact(t *testing.T) { + testLogger, _ := 
loggertest.New("TestDownloadArtifact") + tempConfig := &artifact.Config{} // used only to get os and arch, runtime.GOARCH returns amd64 which is not a valid arch when used in GetArtifactName + + parsedVersion, err := agtversion.ParseVersion("8.9.0") + require.NoError(t, err) + + upgradeDeatils := details.NewDetails(parsedVersion.String(), details.StateRequested, "") + + mockContent := []byte("mock content") + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(mockContent) + require.NoError(t, err) + })) + defer testServer.Close() + + testError := errors.New("test error") + + type testCase struct { + mockNewVerifierFactory verifierFactory + expectedError error + } + + testCases := map[string]testCase{ + "should return path if verifier constructor fails": { + mockNewVerifierFactory: func(version *agtversion.ParsedSemVer, log *logger.Logger, settings *artifact.Config) (download.Verifier, error) { + return nil, testError + }, + expectedError: testError, + }, + "should return path if verifier fails": { + mockNewVerifierFactory: func(version *agtversion.ParsedSemVer, log *logger.Logger, settings *artifact.Config) (download.Verifier, error) { + return &mockVerifier{returnError: testError}, nil + }, + expectedError: testError, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + paths.SetTop(t.TempDir()) + + artifactPath, err := artifact.GetArtifactPath(agentArtifact, *parsedVersion, tempConfig.OS(), tempConfig.Arch(), paths.Downloads()) + require.NoError(t, err) + + settings := artifact.Config{ + RetrySleepInitDuration: 20 * time.Millisecond, + HTTPTransportSettings: httpcommon.HTTPTransportSettings{ + Timeout: 2 * time.Second, + }, + SourceURI: testServer.URL, + TargetDirectory: paths.Downloads(), + } + + a := newArtifactDownloader(&settings, testLogger) + a.newVerifier = tc.mockNewVerifierFactory + + path, err := 
a.downloadArtifact(t.Context(), parsedVersion, testServer.URL, upgradeDeatils, false, true) + require.ErrorIs(t, err, tc.expectedError) + require.Equal(t, artifactPath, path) + }) + } +} + // mockUpgradeDetails returns a *details.Details value that has an observer registered on it for inspecting // certain properties of the object being set and unset. It also returns: // - a *time.Time value, which will be not nil if Metadata.RetryUntil is set on the mock value, diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go index 323c2301c28..fbc2b38f414 100644 --- a/internal/pkg/agent/application/upgrade/upgrade.go +++ b/internal/pkg/agent/application/upgrade/upgrade.go @@ -24,6 +24,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/application/reexec" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download" upgradeErrors "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download/errors" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/details" "github.com/elastic/elastic-agent/internal/pkg/agent/configuration" @@ -276,7 +277,7 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s } u.log.Infow("Upgrading agent", "version", version, "source_uri", sourceURI) - + cleanupPaths := []string{} defer func() { if err != nil { // Add the disk space error to the error chain if it is a disk space error @@ -284,6 +285,15 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s if u.isDiskSpaceErrorFunc(err) { err = goerrors.Join(err, upgradeErrors.ErrInsufficientDiskSpace) } + // If there is an error, we need to clean up downloads and any + // extracted agent files. 
+ for _, path := range cleanupPaths { + rmErr := os.RemoveAll(path) + if rmErr != nil { + u.log.Errorw("error removing path during upgrade cleanup", "error.message", rmErr, "path", path) + err = goerrors.Join(err, rmErr) + } + } } }() @@ -328,6 +338,15 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s } archivePath, err := u.artifactDownloader.downloadArtifact(ctx, parsedVersion, sourceURI, det, skipVerifyOverride, skipDefaultPgp, pgpBytes...) + + // If the archivePath is not empty, then the artifact was downloaded. + // There may still be an error in the download process, so we need to add + // the archive and hash path to the cleanup slice. + if archivePath != "" { + archiveHashPath := download.AddHashExtension(archivePath) + cleanupPaths = append(cleanupPaths, archivePath, archiveHashPath) + } + if err != nil { // Run the same pre-upgrade cleanup task to get rid of any newly downloaded files // This may have an issue if users are upgrading to the same version number. @@ -363,6 +382,20 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s } u.log.Debugf("detected used flavor: %q", detectedFlavor) unpackRes, err := u.unpacker.unpack(version, archivePath, paths.Data(), detectedFlavor) + + // If VersionedHome is empty then unpack has not started unpacking the + // archive yet. There's nothing to clean up. Return the error. + if unpackRes.VersionedHome == "" { + return nil, goerrors.Join(err, fmt.Errorf("versionedhome is empty: %v", unpackRes)) + } + + // If VersionedHome is not empty, it means that the unpack function has + // started extracting the archive. It may have failed while extracting. + // Set up newHome to be cleaned up. 
+ newHome := filepath.Join(paths.Top(), unpackRes.VersionedHome) + + cleanupPaths = append(cleanupPaths, newHome) + if err != nil { return nil, err } @@ -372,12 +405,6 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, rollback bool, s return nil, errors.New("unknown hash") } - if unpackRes.VersionedHome == "" { - return nil, fmt.Errorf("versionedhome is empty: %v", unpackRes) - } - - newHome := filepath.Join(paths.Top(), unpackRes.VersionedHome) - if err := u.copyActionStore(u.log, newHome); err != nil { return nil, fmt.Errorf("failed to copy action store: %w", err) } diff --git a/internal/pkg/agent/application/upgrade/upgrade_test.go b/internal/pkg/agent/application/upgrade/upgrade_test.go index 42e62f0c35b..8f51537b128 100644 --- a/internal/pkg/agent/application/upgrade/upgrade_test.go +++ b/internal/pkg/agent/application/upgrade/upgrade_test.go @@ -1313,12 +1313,13 @@ func TestManualRollback(t *testing.T) { } type mockArtifactDownloader struct { - returnError error - fleetServerURI string + returnError error + returnArchivePath string + fleetServerURI string } func (m *mockArtifactDownloader) downloadArtifact(ctx context.Context, parsedVersion *agtversion.ParsedSemVer, sourceURI string, upgradeDetails *details.Details, skipVerifyOverride, skipDefaultPgp bool, pgpBytes ...string) (_ string, err error) { - return "", m.returnError + return m.returnArchivePath, m.returnError } func (m *mockArtifactDownloader) withFleetServerURI(fleetServerURI string) { @@ -1343,39 +1344,49 @@ func (m *mockUnpacker) unpack(version, archivePath, dataDir string, flavor strin func TestUpgradeErrorHandling(t *testing.T) { log, _ := loggertest.New("test") testError := errors.New("test error") - type upgraderMocker func(upgrader *Upgrader) + + type upgraderMocker func(upgrader *Upgrader, archivePath string, versionedHome string) type testCase struct { - isDiskSpaceErrorResult bool - expectedError error - upgraderMocker upgraderMocker + isDiskSpaceErrorResult bool + 
expectedError error + upgraderMocker upgraderMocker + checkArchiveCleanup bool + checkVersionedHomeCleanup bool } testCases := map[string]testCase{ - "should return error if downloadArtifact fails": { + "should return error and cleanup downloaded archive if downloadArtifact fails after download is complete": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { upgrader.artifactDownloader = &mockArtifactDownloader{ - returnError: testError, + returnError: testError, + returnArchivePath: archivePath, } }, + checkArchiveCleanup: true, }, "should return error if getPackageMetadata fails": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.unpacker = &mockUnpacker{ returnPackageMetadataError: testError, } }, + checkArchiveCleanup: true, }, - "should return error if unpack fails": { + "should return error and cleanup downloaded archive if unpack fails before extracting": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { return agentVersion{ version: upgradeVersion, @@ -1391,12 +1402,44 @@ func TestUpgradeErrorHandling(t *testing.T) { returnUnpackError: testError, } }, + checkArchiveCleanup: true, }, - "should return error if copyActionStore fails": { + 
"should return error and cleanup downloaded archive if unpack fails after extracting": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } + upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { + return agentVersion{ + version: upgradeVersion, + snapshot: false, + hash: metadata.hash, + } + } + upgrader.unpacker = &mockUnpacker{ + returnPackageMetadata: packageMetadata{ + manifest: &v1.PackageManifest{}, + hash: "hash", + }, + returnUnpackError: testError, + returnUnpackResult: UnpackResult{ + Hash: "hash", + VersionedHome: versionedHome, + }, + } + }, + checkArchiveCleanup: true, + checkVersionedHomeCleanup: true, + }, + "should return error and cleanup downloaded artifact and extracted archive if copyActionStore fails": { + isDiskSpaceErrorResult: false, + expectedError: testError, + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { return agentVersion{ version: upgradeVersion, @@ -1411,20 +1454,24 @@ func TestUpgradeErrorHandling(t *testing.T) { }, returnUnpackResult: UnpackResult{ Hash: "hash", - VersionedHome: "versionedHome", + VersionedHome: versionedHome, }, } upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { return testError } }, + checkArchiveCleanup: true, + checkVersionedHomeCleanup: true, }, - "should return error if copyRunDirectory fails": { + "should return error and cleanup downloaded artifact and extracted archive if copyRunDirectory fails": { 
isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { upgrader.artifactDownloader = &mockArtifactDownloader{} + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { return agentVersion{ version: upgradeVersion, @@ -1439,7 +1486,7 @@ func TestUpgradeErrorHandling(t *testing.T) { }, returnUnpackResult: UnpackResult{ Hash: "hash", - VersionedHome: "versionedHome", + VersionedHome: versionedHome, }, } upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { @@ -1449,12 +1496,16 @@ func TestUpgradeErrorHandling(t *testing.T) { return testError } }, + checkArchiveCleanup: true, + checkVersionedHomeCleanup: true, }, - "should return error if changeSymlink fails": { + "should return error and cleanup downloaded artifact and extracted archive if changeSymlink fails": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { return agentVersion{ version: upgradeVersion, @@ -1469,7 +1520,7 @@ func TestUpgradeErrorHandling(t *testing.T) { }, returnUnpackResult: UnpackResult{ Hash: "hash", - VersionedHome: "versionedHome", + VersionedHome: versionedHome, }, } upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { @@ -1485,12 +1536,16 @@ func TestUpgradeErrorHandling(t *testing.T) { return testError } }, + checkArchiveCleanup: true, + 
checkVersionedHomeCleanup: true, }, - "should return error if markUpgrade fails": { + "should return error and cleanup downloaded artifact and extracted archive if markUpgrade fails": { isDiskSpaceErrorResult: false, expectedError: testError, - upgraderMocker: func(upgrader *Upgrader) { - upgrader.artifactDownloader = &mockArtifactDownloader{} + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { + upgrader.artifactDownloader = &mockArtifactDownloader{ + returnArchivePath: archivePath, + } upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion { return agentVersion{ version: upgradeVersion, @@ -1505,7 +1560,7 @@ func TestUpgradeErrorHandling(t *testing.T) { }, returnUnpackResult: UnpackResult{ Hash: "hash", - VersionedHome: "versionedHome", + VersionedHome: versionedHome, }, } upgrader.copyActionStore = func(log *logger.Logger, newHome string) error { @@ -1524,13 +1579,15 @@ func TestUpgradeErrorHandling(t *testing.T) { return testError } }, + checkArchiveCleanup: true, + checkVersionedHomeCleanup: true, }, "should add disk space error to the error chain if downloadArtifact fails with disk space error": { isDiskSpaceErrorResult: true, expectedError: upgradeErrors.ErrInsufficientDiskSpace, - upgraderMocker: func(upgrader *Upgrader) { + upgraderMocker: func(upgrader *Upgrader, archivePath string, versionedHome string) { upgrader.artifactDownloader = &mockArtifactDownloader{ - returnError: upgradeErrors.ErrInsufficientDiskSpace, + returnError: testError, } }, }, @@ -1541,11 +1598,21 @@ func TestUpgradeErrorHandling(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { + baseDir := t.TempDir() + paths.SetTop(baseDir) + mockWatcherHelper := NewMockWatcherHelper(t) upgrader, err := NewUpgrader(log, &artifact.Config{}, nil, mockAgentInfo, mockWatcherHelper) require.NoError(t, err) - tc.upgraderMocker(upgrader) + tc.upgraderMocker(upgrader, filepath.Join(baseDir, 
"mockArchive"), "versionedHome") + + // Create the test files for all the cases + err = os.WriteFile(filepath.Join(baseDir, "mockArchive"), []byte("test"), 0o600) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(baseDir, "versionedHome"), []byte("test"), 0o600) + require.NoError(t, err) upgrader.isDiskSpaceErrorFunc = func(err error) bool { return tc.isDiskSpaceErrorResult @@ -1553,28 +1620,22 @@ func TestUpgradeErrorHandling(t *testing.T) { _, err = upgrader.Upgrade(context.Background(), "9.0.0", false, "", nil, details.NewDetails("9.0.0", details.StateRequested, "test"), true, true) require.ErrorIs(t, err, tc.expectedError) - }) - } -} -type mockSender struct{} - -func (m *mockSender) Send(ctx context.Context, method, path string, params url.Values, headers http.Header, body io.Reader) (*http.Response, error) { - return nil, nil -} + // If the downloaded archive needs to be cleaned up assert that it is indeed cleaned up, if not assert that it still exists. The downloaded archive is a mock file that is created for all tests cases. + if tc.checkArchiveCleanup { + require.NoFileExists(t, filepath.Join(baseDir, "mockArchive")) + } else { + require.FileExists(t, filepath.Join(baseDir, "mockArchive")) + } -func (m *mockSender) URI() string { - return "mockURI" -} -func TestSetClient(t *testing.T) { - log, _ := loggertest.New("test") - upgrader := &Upgrader{ - log: log, - artifactDownloader: &mockArtifactDownloader{}, + // If the extracted agent needs to be cleaned up assert that it is indeed cleaned up, if not assert that it still exists. Versioned home is a mock file that is created for all test cases. 
+ if tc.checkVersionedHomeCleanup { + require.NoFileExists(t, filepath.Join(baseDir, "versionedHome")) + } else { + require.FileExists(t, filepath.Join(baseDir, "versionedHome")) + } + }) } - - upgrader.SetClient(&mockSender{}) - require.Equal(t, "mockURI", upgrader.artifactDownloader.(*mockArtifactDownloader).fleetServerURI) } func TestCopyActionStore(t *testing.T) { @@ -1761,3 +1822,23 @@ func TestCopyRunDirectory(t *testing.T) { }) } } + +type mockSender struct{} + +func (m *mockSender) Send(ctx context.Context, method, path string, params url.Values, headers http.Header, body io.Reader) (*http.Response, error) { + return nil, nil +} + +func (m *mockSender) URI() string { + return "mockURI" +} +func TestSetClient(t *testing.T) { + log, _ := loggertest.New("test") + upgrader := &Upgrader{ + log: log, + artifactDownloader: &mockArtifactDownloader{}, + } + + upgrader.SetClient(&mockSender{}) + require.Equal(t, "mockURI", upgrader.artifactDownloader.(*mockArtifactDownloader).fleetServerURI) +}