@@ -6,13 +6,20 @@ package handlers
66
77import (
88 "context"
9+ "crypto/ecdsa"
10+ "crypto/elliptic"
11+ "crypto/rand"
12+ "crypto/sha256"
13+ "crypto/x509"
14+ "encoding/base64"
915 "testing"
1016
1117 "github.com/stretchr/testify/mock"
1218 "github.com/stretchr/testify/require"
1319
1420 "github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator"
1521 "github.com/elastic/elastic-agent/internal/pkg/agent/application/reexec"
22+ "github.com/elastic/elastic-agent/internal/pkg/agent/protection"
1623 "github.com/elastic/elastic-agent/internal/pkg/core/backoff"
1724 "github.com/elastic/elastic-agent/internal/pkg/fleetapi"
1825 "github.com/elastic/elastic-agent/pkg/core/logger/loggertest"
@@ -21,8 +28,10 @@ import (
2128
2229func TestActionMigratelHandler (t * testing.T ) {
2330 log , _ := loggertest .New ("" )
24- mockAgentInfo := mockinfo .NewAgent (t )
2531 t .Run ("wrong action type" , func (t * testing.T ) {
32+
33+ mockAgentInfo := mockinfo .NewAgent (t )
34+
2635 action := & fleetapi.ActionSettings {}
2736 ack := & fakeAcker {}
2837 ack .On ("Ack" , t .Context (), action ).Return (nil )
@@ -31,6 +40,7 @@ func TestActionMigratelHandler(t *testing.T) {
3140 coord := & fakeMigrateCoordinator {}
3241 coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
3342 coord .On ("ReExec" , mock .Anything , mock .Anything )
43+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : nil })
3444
3545 h := NewMigrate (log , mockAgentInfo , coord )
3646 require .NotNil (t , h .Handle (t .Context (), action , ack ))
@@ -39,6 +49,7 @@ func TestActionMigratelHandler(t *testing.T) {
3949 })
4050
4151 t .Run ("tamper protected agent" , func (t * testing.T ) {
52+ mockAgentInfo := mockinfo .NewAgent (t )
4253 action := & fleetapi.ActionMigrate {
4354 ActionType : "MIGRATE" ,
4455 }
@@ -51,6 +62,7 @@ func TestActionMigratelHandler(t *testing.T) {
5162 coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
5263 coord .On ("ReExec" , mock .Anything , mock .Anything )
5364 coord .On ("HasEndpoint" ).Return (true )
65+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : nil })
5466
5567 h := NewMigrate (log , mockAgentInfo , coord )
5668 h .tamperProtectionFn = func () bool { return true }
@@ -63,6 +75,8 @@ func TestActionMigratelHandler(t *testing.T) {
6375 })
6476
6577 t .Run ("action propagated to coordinator" , func (t * testing.T ) {
78+ mockAgentInfo := mockinfo .NewAgent (t )
79+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
6680 action := & fleetapi.ActionMigrate {}
6781
6882 ack := & fakeAcker {}
@@ -72,6 +86,56 @@ func TestActionMigratelHandler(t *testing.T) {
7286 coord := & fakeMigrateCoordinator {}
7387 coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
7488 coord .On ("ReExec" , mock .Anything , mock .Anything )
89+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : nil })
90+
91+ h := NewMigrate (log , mockAgentInfo , coord )
92+ h .tamperProtectionFn = func () bool { return false }
93+
94+ require .Nil (t , h .Handle (t .Context (), action , ack ))
95+ coord .AssertNumberOfCalls (t , "Migrate" , 1 )
96+
97+ // ack delegated to migrate coordinator
98+ ack .AssertNumberOfCalls (t , "Ack" , 0 )
99+ ack .AssertNumberOfCalls (t , "Commit" , 0 )
100+ coord .AssertCalled (t , "ReExec" , mock .Anything , mock .Anything )
101+ })
102+
103+ t .Run ("signature present" , func (t * testing.T ) {
104+ mockAgentInfo := mockinfo .NewAgent (t )
105+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
106+
107+ private , signatureValidationKey , err := genKeys ()
108+ require .NoError (t , err )
109+
110+ action := & fleetapi.ActionMigrate {
111+ ActionID : "123" ,
112+ ActionType : "MIGRATE" ,
113+ Data : fleetapi.ActionMigrateData {
114+ EnrollmentToken : "et-123" ,
115+ },
116+ }
117+
118+ actionBytes := []byte ("{\" action_id\" :\" 123\" ,\" agents\" :[\" agent-id\" ],\" type\" :\" MIGRATE\" ,\" data\" :{\" target_uri\" :\" \" ,\" enrollment_token\" :\" et-123\" ,\" settings\" :null}}" )
119+
120+ signature , err := sign (actionBytes , private )
121+ require .NoError (t , err )
122+
123+ base64Data := base64 .StdEncoding .EncodeToString (actionBytes )
124+ base64Signature := base64 .StdEncoding .EncodeToString (signature )
125+
126+ action .Signature = & fleetapi.Signed {
127+ Data : base64Data ,
128+ Signature : base64Signature ,
129+ }
130+
131+ ack := & fakeAcker {}
132+ ack .On ("Ack" , t .Context (), action ).Return (nil )
133+ ack .On ("Commit" , t .Context ()).Return (nil )
134+
135+ coord := & fakeMigrateCoordinator {}
136+ coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
137+ coord .On ("ReExec" , mock .Anything , mock .Anything )
138+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : signatureValidationKey })
75139
76140 h := NewMigrate (log , mockAgentInfo , coord )
77141 h .tamperProtectionFn = func () bool { return false }
@@ -85,7 +149,142 @@ func TestActionMigratelHandler(t *testing.T) {
85149 coord .AssertCalled (t , "ReExec" , mock .Anything , mock .Anything )
86150 })
87151
152+ t .Run ("signature present, action not signed" , func (t * testing.T ) {
153+ mockAgentInfo := mockinfo .NewAgent (t )
154+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
155+
156+ _ , signatureValidationKey , err := genKeys ()
157+ require .NoError (t , err )
158+
159+ action := & fleetapi.ActionMigrate {
160+ ActionID : "123" ,
161+ ActionType : "MIGRATE" ,
162+ Data : fleetapi.ActionMigrateData {
163+ EnrollmentToken : "et-123" ,
164+ },
165+ }
166+
167+ ack := & fakeAcker {}
168+ ack .On ("Ack" , t .Context (), action ).Return (nil )
169+ ack .On ("Commit" , t .Context ()).Return (nil )
170+
171+ coord := & fakeMigrateCoordinator {}
172+ coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
173+ coord .On ("ReExec" , mock .Anything , mock .Anything )
174+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : signatureValidationKey })
175+
176+ h := NewMigrate (log , mockAgentInfo , coord )
177+ h .tamperProtectionFn = func () bool { return false }
178+
179+ require .ErrorIs (t , h .Handle (t .Context (), action , ack ), protection .ErrNotSigned )
180+ coord .AssertNumberOfCalls (t , "Migrate" , 0 )
181+
182+ // ack delegated to migrate coordinator
183+ ack .AssertNumberOfCalls (t , "Ack" , 0 )
184+ ack .AssertNumberOfCalls (t , "Commit" , 0 )
185+ coord .AssertNumberOfCalls (t , "ReExec" , 0 )
186+ })
187+
188+ t .Run ("signature not present" , func (t * testing.T ) {
189+ mockAgentInfo := mockinfo .NewAgent (t )
190+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
191+
192+ private , _ , err := genKeys ()
193+ require .NoError (t , err )
194+
195+ action := & fleetapi.ActionMigrate {
196+ ActionID : "123" ,
197+ ActionType : "MIGRATE" ,
198+ Data : fleetapi.ActionMigrateData {
199+ EnrollmentToken : "et-123" ,
200+ },
201+ }
202+
203+ actionBytes := []byte ("{\" action_id\" :\" 123\" ,\" agents\" :[\" agent-id\" ],\" type\" :\" MIGRATE\" ,\" data\" :{\" target_uri\" :\" \" ,\" enrollment_token\" :\" et-123\" ,\" settings\" :null}}" )
204+
205+ signature , err := sign (actionBytes , private )
206+ require .NoError (t , err )
207+
208+ base64Data := base64 .StdEncoding .EncodeToString (actionBytes )
209+ base64Signature := base64 .StdEncoding .EncodeToString (signature )
210+
211+ action .Signature = & fleetapi.Signed {
212+ Data : base64Data ,
213+ Signature : base64Signature ,
214+ }
215+
216+ ack := & fakeAcker {}
217+ ack .On ("Ack" , t .Context (), action ).Return (nil )
218+ ack .On ("Commit" , t .Context ()).Return (nil )
219+
220+ coord := & fakeMigrateCoordinator {}
221+ coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
222+ coord .On ("ReExec" , mock .Anything , mock .Anything )
223+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : nil })
224+
225+ h := NewMigrate (log , mockAgentInfo , coord )
226+ h .tamperProtectionFn = func () bool { return false }
227+
228+ require .Nil (t , h .Handle (t .Context (), action , ack ))
229+ coord .AssertNumberOfCalls (t , "Migrate" , 1 )
230+
231+ // ack delegated to migrate coordinator
232+ ack .AssertNumberOfCalls (t , "Ack" , 0 )
233+ ack .AssertNumberOfCalls (t , "Commit" , 0 )
234+ coord .AssertCalled (t , "ReExec" , mock .Anything , mock .Anything )
235+ })
236+
237+ t .Run ("malformed signature" , func (t * testing.T ) {
238+ mockAgentInfo := mockinfo .NewAgent (t )
239+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
240+
241+ _ , signatureValidationKey , err := genKeys ()
242+ require .NoError (t , err )
243+
244+ private , _ , err := genKeys ()
245+ require .NoError (t , err )
246+
247+ action := & fleetapi.ActionMigrate {
248+ ActionID : "123" ,
249+ ActionType : "MIGRATE" ,
250+ Data : fleetapi.ActionMigrateData {
251+ EnrollmentToken : "et-123" ,
252+ },
253+ }
254+
255+ actionBytes := []byte ("{\" action_id\" :\" 123\" ,\" agents\" :[\" agent-id\" ],\" type\" :\" MIGRATE\" ,\" data\" :{\" target_uri\" :\" \" ,\" enrollment_token\" :\" et-123\" ,\" settings\" :null}}" )
256+
257+ signature , err := sign (actionBytes , private )
258+ require .NoError (t , err )
259+
260+ base64Data := base64 .StdEncoding .EncodeToString (actionBytes )
261+ base64Signature := base64 .StdEncoding .EncodeToString (signature )
262+
263+ action .Signature = & fleetapi.Signed {
264+ Data : base64Data ,
265+ Signature : base64Signature ,
266+ }
267+
268+ ack := & fakeAcker {}
269+ ack .On ("Ack" , t .Context (), action ).Return (nil )
270+ ack .On ("Commit" , t .Context ()).Return (nil )
271+
272+ coord := & fakeMigrateCoordinator {}
273+ coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (nil )
274+ coord .On ("ReExec" , mock .Anything , mock .Anything )
275+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : signatureValidationKey })
276+
277+ h := NewMigrate (log , mockAgentInfo , coord )
278+ h .tamperProtectionFn = func () bool { return false }
279+
280+ err = h .Handle (t .Context (), action , ack )
281+ require .ErrorIs (t , err , protection .ErrInvalidSignature )
282+ coord .AssertNumberOfCalls (t , "Migrate" , 0 )
283+ })
284+
88285 t .Run ("fleet server" , func (t * testing.T ) {
286+ mockAgentInfo := mockinfo .NewAgent (t )
287+ mockAgentInfo .On ("AgentID" ).Return ("agent-id" )
89288 action := & fleetapi.ActionMigrate {}
90289
91290 ack := & fakeAcker {}
@@ -95,6 +294,7 @@ func TestActionMigratelHandler(t *testing.T) {
95294 coord := & fakeMigrateCoordinator {}
96295 coord .On ("Migrate" , mock .Anything , mock .Anything ).Return (coordinator .ErrFleetServer )
97296 coord .On ("ReExec" , mock .Anything , mock .Anything )
297+ coord .On ("Protection" ).Return (protection.Config {SignatureValidationKey : nil })
98298
99299 h := NewMigrate (log , mockAgentInfo , coord )
100300 h .tamperProtectionFn = func () bool { return false }
@@ -126,3 +326,23 @@ func (f *fakeMigrateCoordinator) HasEndpoint() bool {
126326 args := f .Called ()
127327 return args .Bool (0 )
128328}
329+
330+ func (f * fakeMigrateCoordinator ) Protection () protection.Config {
331+ args := f .Called ()
332+ return args .Get (0 ).(protection.Config )
333+ }
334+
335+ func genKeys () (pk * ecdsa.PrivateKey , pubK []byte , err error ) {
336+ pk , err = ecdsa .GenerateKey (elliptic .P256 (), rand .Reader )
337+ if err != nil {
338+ return
339+ }
340+
341+ pubK , err = x509 .MarshalPKIXPublicKey (& pk .PublicKey )
342+ return pk , pubK , err
343+ }
344+
345+ func sign (data []byte , pk * ecdsa.PrivateKey ) ([]byte , error ) {
346+ hash := sha256 .Sum256 (data )
347+ return ecdsa .SignASN1 (rand .Reader , pk , hash [:])
348+ }
0 commit comments