Skip to content

Commit a0dc3b6

Browse files
authored
refactor: add hook interface for extended functionality (#8585)
1 parent 9dcd06f commit a0dc3b6

File tree

14 files changed

+795
-198
lines changed

14 files changed

+795
-198
lines changed

integration/module_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
package integration
44

55
import (
6-
"github.com/aquasecurity/trivy/pkg/types"
76
"path/filepath"
87
"testing"
98

9+
"github.com/aquasecurity/trivy/pkg/extension"
1010
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
11-
"github.com/aquasecurity/trivy/pkg/scan/post"
11+
"github.com/aquasecurity/trivy/pkg/types"
1212
)
1313

1414
func TestModule(t *testing.T) {
@@ -52,7 +52,7 @@ func TestModule(t *testing.T) {
5252

5353
t.Cleanup(func() {
5454
analyzer.DeregisterAnalyzer("spring4shell")
55-
post.DeregisterPostScanner("spring4shell")
55+
extension.DeregisterHook("spring4shell")
5656
})
5757

5858
// Run Trivy

internal/hooktest/hook.go

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package hooktest
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/aquasecurity/trivy/pkg/extension"
9+
"github.com/aquasecurity/trivy/pkg/flag"
10+
"github.com/aquasecurity/trivy/pkg/types"
11+
)
12+
13+
type testHook struct{}
14+
15+
func (*testHook) Name() string {
16+
return "test"
17+
}
18+
19+
func (*testHook) Version() int {
20+
return 1
21+
}
22+
23+
// RunHook implementation
24+
func (*testHook) PreRun(ctx context.Context, opts flag.Options) error {
25+
if opts.GlobalOptions.ConfigFile == "bad-config" {
26+
return errors.New("bad pre-run")
27+
}
28+
return nil
29+
}
30+
31+
func (*testHook) PostRun(ctx context.Context, opts flag.Options) error {
32+
if opts.GlobalOptions.ConfigFile == "bad-config" {
33+
return errors.New("bad post-run")
34+
}
35+
return nil
36+
}
37+
38+
// ScanHook implementation
39+
func (*testHook) PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
40+
if target.Name == "bad-pre" {
41+
return errors.New("bad pre-scan")
42+
}
43+
target.Name += " (pre-scan)"
44+
return nil
45+
}
46+
47+
func (*testHook) PostScan(ctx context.Context, results types.Results) (types.Results, error) {
48+
for i, r := range results {
49+
if r.Target == "bad" {
50+
return nil, errors.New("bad")
51+
}
52+
for j := range r.Vulnerabilities {
53+
results[i].Vulnerabilities[j].References = []string{
54+
"https://example.com/post-scan",
55+
}
56+
}
57+
}
58+
return results, nil
59+
}
60+
61+
// ReportHook implementation
62+
func (*testHook) PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
63+
if report.ArtifactName == "bad-report" {
64+
return errors.New("bad pre-report")
65+
}
66+
67+
// Modify the report
68+
for i := range report.Results {
69+
for j := range report.Results[i].Vulnerabilities {
70+
report.Results[i].Vulnerabilities[j].Title = "Modified by pre-report hook"
71+
}
72+
}
73+
return nil
74+
}
75+
76+
func (*testHook) PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
77+
if report.ArtifactName == "bad-report" {
78+
return errors.New("bad post-report")
79+
}
80+
81+
// Modify the report
82+
for i := range report.Results {
83+
for j := range report.Results[i].Vulnerabilities {
84+
report.Results[i].Vulnerabilities[j].Description = "Modified by post-report hook"
85+
}
86+
}
87+
return nil
88+
}
89+
90+
func Init(t *testing.T) {
91+
h := &testHook{}
92+
extension.RegisterHook(h)
93+
t.Cleanup(func() {
94+
extension.DeregisterHook(h.Name())
95+
})
96+
}

pkg/commands/artifact/run.go

+31-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/aquasecurity/trivy/pkg/cache"
1616
"github.com/aquasecurity/trivy/pkg/commands/operation"
1717
"github.com/aquasecurity/trivy/pkg/db"
18+
"github.com/aquasecurity/trivy/pkg/extension"
1819
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
1920
"github.com/aquasecurity/trivy/pkg/fanal/artifact"
2021
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -277,7 +278,6 @@ func (r *runner) Report(ctx context.Context, opts flag.Options, report types.Rep
277278
if err := pkgReport.Write(ctx, report, opts); err != nil {
278279
return xerrors.Errorf("unable to write results: %w", err)
279280
}
280-
281281
return nil
282282
}
283283

@@ -375,12 +375,32 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err
375375
return v.SafeWriteConfigAs("trivy-default.yaml")
376376
}
377377

378+
// Call pre-run hooks
379+
if err := extension.PreRun(ctx, opts); err != nil {
380+
return xerrors.Errorf("pre run error: %w", err)
381+
}
382+
383+
// Run the application
384+
report, err := run(ctx, opts, targetKind)
385+
if err != nil {
386+
return xerrors.Errorf("run error: %w", err)
387+
}
388+
389+
// Call post-run hooks
390+
if err := extension.PostRun(ctx, opts); err != nil {
391+
return xerrors.Errorf("post run error: %w", err)
392+
}
393+
394+
return operation.Exit(opts, report.Results.Failed(), report.Metadata)
395+
}
396+
397+
func run(ctx context.Context, opts flag.Options, targetKind TargetKind) (types.Report, error) {
378398
r, err := NewRunner(ctx, opts)
379399
if err != nil {
380400
if errors.Is(err, SkipScan) {
381-
return nil
401+
return types.Report{}, nil
382402
}
383-
return xerrors.Errorf("init error: %w", err)
403+
return types.Report{}, xerrors.Errorf("init error: %w", err)
384404
}
385405
defer r.Close(ctx)
386406

@@ -395,24 +415,27 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err
395415

396416
scanFunction, exists := scans[targetKind]
397417
if !exists {
398-
return xerrors.Errorf("unknown target kind: %s", targetKind)
418+
return types.Report{}, xerrors.Errorf("unknown target kind: %s", targetKind)
399419
}
400420

421+
// 1. Scan the artifact
401422
report, err := scanFunction(ctx, opts)
402423
if err != nil {
403-
return xerrors.Errorf("%s scan error: %w", targetKind, err)
424+
return types.Report{}, xerrors.Errorf("%s scan error: %w", targetKind, err)
404425
}
405426

427+
// 2. Filter the results
406428
report, err = r.Filter(ctx, opts, report)
407429
if err != nil {
408-
return xerrors.Errorf("filter error: %w", err)
430+
return types.Report{}, xerrors.Errorf("filter error: %w", err)
409431
}
410432

433+
// 3. Report the results
411434
if err = r.Report(ctx, opts, report); err != nil {
412-
return xerrors.Errorf("report error: %w", err)
435+
return types.Report{}, xerrors.Errorf("report error: %w", err)
413436
}
414437

415-
return operation.Exit(opts, report.Results.Failed(), report.Metadata)
438+
return report, nil
416439
}
417440

418441
func disabledAnalyzers(opts flag.Options) []analyzer.Type {

pkg/extension/hook.go

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package extension
2+
3+
import (
4+
"context"
5+
"sort"
6+
7+
"github.com/samber/lo"
8+
"golang.org/x/xerrors"
9+
10+
"github.com/aquasecurity/trivy/pkg/flag"
11+
"github.com/aquasecurity/trivy/pkg/types"
12+
)
13+
14+
var hooks = make(map[string]Hook)
15+
16+
func RegisterHook(s Hook) {
17+
// Avoid duplication
18+
hooks[s.Name()] = s
19+
}
20+
21+
func DeregisterHook(name string) {
22+
delete(hooks, name)
23+
}
24+
25+
// Hook is an interface that defines the methods for a hook.
26+
type Hook interface {
27+
// Name returns the name of the extension.
28+
Name() string
29+
}
30+
31+
// RunHook is a extension that is called before and after all the processes.
32+
type RunHook interface {
33+
Hook
34+
35+
// PreRun is called before all the processes.
36+
PreRun(ctx context.Context, opts flag.Options) error
37+
38+
// PostRun is called after all the processes.
39+
PostRun(ctx context.Context, opts flag.Options) error
40+
}
41+
42+
// ScanHook is a extension that is called before and after the scan.
43+
type ScanHook interface {
44+
Hook
45+
46+
// PreScan is called before the scan. It can modify the scan target.
47+
// It may be called on the server side in client/server mode.
48+
PreScan(ctx context.Context, target *types.ScanTarget, opts types.ScanOptions) error
49+
50+
// PostScan is called after the scan. It can modify the results.
51+
// It may be called on the server side in client/server mode.
52+
// NOTE: Wasm modules cannot directly modify the passed results,
53+
// so it returns a copy of the results.
54+
PostScan(ctx context.Context, results types.Results) (types.Results, error)
55+
}
56+
57+
// ReportHook is a extension that is called before and after the report is written.
58+
type ReportHook interface {
59+
Hook
60+
61+
// PreReport is called before the report is written.
62+
// It can modify the report. It is called on the client side.
63+
PreReport(ctx context.Context, report *types.Report, opts flag.Options) error
64+
65+
// PostReport is called after the report is written.
66+
// It can modify the report. It is called on the client side.
67+
PostReport(ctx context.Context, report *types.Report, opts flag.Options) error
68+
}
69+
70+
func PreRun(ctx context.Context, opts flag.Options) error {
71+
for _, e := range Hooks() {
72+
h, ok := e.(RunHook)
73+
if !ok {
74+
continue
75+
}
76+
if err := h.PreRun(ctx, opts); err != nil {
77+
return xerrors.Errorf("%s pre run error: %w", e.Name(), err)
78+
}
79+
}
80+
return nil
81+
}
82+
83+
// PostRun is a hook that is called after all the processes.
84+
func PostRun(ctx context.Context, opts flag.Options) error {
85+
for _, e := range Hooks() {
86+
h, ok := e.(RunHook)
87+
if !ok {
88+
continue
89+
}
90+
if err := h.PostRun(ctx, opts); err != nil {
91+
return xerrors.Errorf("%s post run error: %w", e.Name(), err)
92+
}
93+
}
94+
return nil
95+
}
96+
97+
// PreScan is a hook that is called before the scan.
98+
func PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
99+
for _, e := range Hooks() {
100+
h, ok := e.(ScanHook)
101+
if !ok {
102+
continue
103+
}
104+
if err := h.PreScan(ctx, target, options); err != nil {
105+
return xerrors.Errorf("%s pre scan error: %w", e.Name(), err)
106+
}
107+
}
108+
return nil
109+
}
110+
111+
// PostScan is a hook that is called after the scan.
112+
func PostScan(ctx context.Context, results types.Results) (types.Results, error) {
113+
var err error
114+
for _, e := range Hooks() {
115+
h, ok := e.(ScanHook)
116+
if !ok {
117+
continue
118+
}
119+
results, err = h.PostScan(ctx, results)
120+
if err != nil {
121+
return nil, xerrors.Errorf("%s post scan error: %w", e.Name(), err)
122+
}
123+
}
124+
return results, nil
125+
}
126+
127+
// PreReport is a hook that is called before the report is written.
128+
func PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
129+
for _, e := range Hooks() {
130+
h, ok := e.(ReportHook)
131+
if !ok {
132+
continue
133+
}
134+
if err := h.PreReport(ctx, report, opts); err != nil {
135+
return xerrors.Errorf("%s pre report error: %w", e.Name(), err)
136+
}
137+
}
138+
return nil
139+
}
140+
141+
// PostReport is a hook that is called after the report is written.
142+
func PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
143+
for _, e := range Hooks() {
144+
h, ok := e.(ReportHook)
145+
if !ok {
146+
continue
147+
}
148+
if err := h.PostReport(ctx, report, opts); err != nil {
149+
return xerrors.Errorf("%s post report error: %w", e.Name(), err)
150+
}
151+
}
152+
return nil
153+
}
154+
155+
// Hooks returns the list of hooks.
156+
func Hooks() []Hook {
157+
hooks := lo.Values(hooks)
158+
sort.Slice(hooks, func(i, j int) bool {
159+
return hooks[i].Name() < hooks[j].Name()
160+
})
161+
return hooks
162+
}

0 commit comments

Comments
 (0)