Skip to content

Commit 10ae7e0

Browse files
author
Lei jin
committed
Fix concurrency issue of Component Registery o Azure Session
Add the test cases for component registeration
1 parent 8d99b1b commit 10ae7e0

File tree

2 files changed

+366
-16
lines changed

2 files changed

+366
-16
lines changed
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
// ----------------------------------------------------------------------------------
2+
//
3+
// Copyright Microsoft Corporation
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
// ----------------------------------------------------------------------------------
14+
using Microsoft.Azure.Commands.Common.Authentication;
15+
using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
16+
using Microsoft.Azure.Commands.Common.Authentication.Abstractions.Models;
17+
18+
using System;
19+
using System.Collections.Concurrent;
20+
using System.Collections.Generic;
21+
using System.Diagnostics;
22+
using System.Linq;
23+
using System.Threading;
24+
using System.Threading.Tasks;
25+
26+
using Xunit;
27+
28+
namespace Authentication.Abstractions.Test
29+
{
30+
public class AzureSessionTest : IDisposable
31+
32+
{
33+
// Concrete implementation of AzureSession for testing
34+
private class TestAzureSession : AzureSession
35+
{
36+
public override TraceLevel AuthenticationLegacyTraceLevel { get; set; }
37+
public override TraceListenerCollection AuthenticationTraceListeners => Trace.Listeners;
38+
public override SourceLevels AuthenticationTraceSourceLevel { get; set; }
39+
}
40+
41+
private IAzureSession oldSession = null;
42+
43+
public AzureSessionTest()
44+
{
45+
try
46+
{
47+
IAzureSession oldSession = AzureSession.Instance;
48+
49+
}
50+
catch (Exception)
51+
{
52+
}
53+
AzureSession.Initialize(() => new TestAzureSession(), true);
54+
}
55+
56+
public void Dispose()
57+
{
58+
// Assign AzureSession.Instance back to oldSession
59+
AzureSession.Initialize(() => oldSession, true);
60+
}
61+
62+
private class TestComponent
63+
{
64+
public string Name { get; set; }
65+
private int id;
66+
private ConcurrentQueue<int> clist = null;
67+
public int Size
68+
{
69+
get => clist.Count();
70+
}
71+
72+
public int Id
73+
{
74+
get => id;
75+
}
76+
77+
public TestComponent(string name, int id)
78+
{
79+
Name = string.Empty;
80+
clist = new ConcurrentQueue<int>();
81+
this.id = id;
82+
}
83+
84+
public void Append(int i)
85+
{
86+
clist.Enqueue(i);
87+
}
88+
}
89+
90+
private Object lockObject = new Object();
91+
92+
// Function to register and retrieve component
93+
private Dictionary<string, int> RegisterAndRetrieveComponent(string componentName, int componentValue, bool overwritten)
94+
{
95+
AzureSession.Instance.RegisterComponent(componentName, () => new TestComponent(componentName, componentValue), overwritten);
96+
AzureSession.Instance.TryGetComponent(componentName, out TestComponent component);
97+
lock (lockObject)
98+
{
99+
component.Append(1);
100+
return new Dictionary<string, int>
101+
{
102+
{ "id", component.Id },
103+
{ "size", component.Size }
104+
};
105+
}
106+
}
107+
108+
[Fact]
109+
public void TestClearComponents()
110+
{
111+
string testComponent1 = "TestComponent1";
112+
string testComponent2 = "TestComponent2";
113+
114+
// Register components
115+
AzureSession.Instance.RegisterComponent(testComponent1, () => "Value1");
116+
AzureSession.Instance.RegisterComponent(testComponent2, () => "Value2");
117+
118+
// Clear all components
119+
AzureSession.Instance.ClearComponents();
120+
121+
// Verify they are gone
122+
Assert.False(AzureSession.Instance.TryGetComponent(testComponent1, out string _));
123+
Assert.False(AzureSession.Instance.TryGetComponent(testComponent2, out string _));
124+
}
125+
126+
[Fact]
127+
public void TestComponentRegistrationDifferentComponentNoOverwritten()
128+
{
129+
string testComponent = "TestComponent";
130+
131+
var tasks = new List<Task<Dictionary<string, int>>>();
132+
for (int i = 0; i < 10; i++)
133+
{
134+
tasks.Add(new Task<Dictionary<string,int>>(
135+
(object state) =>
136+
{
137+
int i = (int)state;
138+
return RegisterAndRetrieveComponent($"{testComponent}{i}", i, false);
139+
},
140+
i));
141+
}
142+
143+
foreach(var task in tasks)
144+
{
145+
task.Start();
146+
}
147+
Task.WaitAll(tasks.ToArray());
148+
149+
// Verify the results
150+
for (int i = 0; i < 10; i++)
151+
{
152+
var result = tasks[i].Result;
153+
Assert.Equal(1, result["size"]);
154+
}
155+
AzureSession.Instance.ClearComponents();
156+
}
157+
158+
[Fact]
159+
public void TestComponentRegistrationDifferentComponentOverwritten()
160+
{
161+
string testComponent = "TestComponent";
162+
163+
var tasks = new List<Task<Dictionary<string, int>>>();
164+
for (int i = 0; i < 10; i++)
165+
{
166+
tasks.Add(new Task<Dictionary<string, int>>(
167+
(object state) =>
168+
{
169+
int i = (int)state;
170+
return RegisterAndRetrieveComponent($"{testComponent}{i}", i, true);
171+
},
172+
i));
173+
}
174+
175+
foreach (var task in tasks)
176+
{
177+
task.Start();
178+
}
179+
Task.WaitAll(tasks.ToArray());
180+
181+
// Verify the results
182+
for (int i = 0; i < 10; i++)
183+
{
184+
var result = tasks[i].Result;
185+
Assert.Equal(1, result["size"]);
186+
}
187+
AzureSession.Instance.ClearComponents();
188+
}
189+
190+
[Fact]
191+
public void TestComponentRegistrationSameComponentNoOverwritten()
192+
{
193+
string testComponent = "TestComponent";
194+
195+
// Create 10 tasks to run the function in parallel
196+
var tasks = new List<Task<Dictionary<string, int>>>();
197+
for (int i = 0; i < 10; i++)
198+
{
199+
tasks.Add(new Task<Dictionary<string, int>>(
200+
(object state) =>
201+
{
202+
int i = (int)state;
203+
return RegisterAndRetrieveComponent(testComponent, i, false);
204+
},
205+
i));
206+
}
207+
208+
foreach (var task in tasks)
209+
{
210+
task.Start();
211+
}
212+
213+
Task.WaitAll(tasks.ToArray());
214+
215+
// Verify the results
216+
var results = new int[10];
217+
218+
Assert.Single(tasks.Select(t => t.Result["id"]).Distinct());
219+
var checkList = tasks.Select(t => t.Result["size"]);
220+
Assert.Equal(10, checkList.Distinct().Count());
221+
Assert.Equal(10, checkList.Max());
222+
Assert.Equal(1, checkList.Min());
223+
AzureSession.Instance.ClearComponents();
224+
}
225+
226+
[Fact]
227+
public void TestComponentRegistrationSameComponentOverwritten()
228+
{
229+
string testComponent = "TestComponent";
230+
231+
// Create 10 tasks to run the function in parallel
232+
var tasks = new List<Task<Dictionary<string, int>>>();
233+
for (int i = 0; i < 10; i++)
234+
{
235+
tasks.Add(new Task<Dictionary<string, int>>(
236+
(object state) =>
237+
{
238+
int i = (int)state;
239+
return RegisterAndRetrieveComponent(testComponent, i, true);
240+
},
241+
i));
242+
}
243+
244+
foreach (var task in tasks)
245+
{
246+
task.Start();
247+
}
248+
249+
Task.WaitAll(tasks.ToArray());
250+
251+
// Verify the results
252+
AzureSession.Instance.TryGetComponent(testComponent, out TestComponent component);
253+
Assert.Equal(1, component.Size);
254+
void CheckResults(List<Task<Dictionary<string, int>>> tasks, int id)
255+
{
256+
var checkList = tasks.Where(t => t.Result["id"] == id);
257+
var count = checkList.Count();
258+
Assert.Equal(count, checkList.Distinct().Count());
259+
if (count > 0)
260+
{
261+
Assert.Equal(count, checkList?.Select(t => t.Result["size"])?.Max());
262+
}
263+
}
264+
for (int i = 0; i < 10; i++)
265+
{
266+
Console.WriteLine($"id={i}");
267+
CheckResults(tasks, i);
268+
}
269+
AzureSession.Instance.ClearComponents();
270+
}
271+
272+
[Fact]
273+
public void TestComponentRegistrationAndUnregistrationInDifferentThreads()
274+
{
275+
string[] testComponents = { "TestComponent1", "TestComponent2", "TestComponent3" };
276+
string componentValue = "TestValue";
277+
278+
Func<string, string, string> RegisterAndUnregisterComponent = (testComponent, componentValue) =>
279+
{
280+
var taskRegister = Task.Run(() => AzureSession.Instance.RegisterComponent(testComponent, () => componentValue));
281+
taskRegister.Wait();
282+
283+
Assert.True(AzureSession.Instance.TryGetComponent(testComponent, out string retrievedValue));
284+
Assert.Equal(componentValue, retrievedValue);
285+
286+
var unregisterTask = Task.Run(() => AzureSession.Instance.UnregisterComponent<string>(testComponent));
287+
unregisterTask.Wait();
288+
return retrievedValue;
289+
};
290+
291+
// Register components in parallel
292+
var tasks = new List<Task>();
293+
foreach (var component in testComponents)
294+
{
295+
Task.Run(() => RegisterAndUnregisterComponent(component, componentValue)).ContinueWith(t => tasks.Add(t));
296+
}
297+
298+
// Wait for all register tasks to complete
299+
Task.WaitAll(tasks.ToArray());
300+
301+
// Verify components are unregistered
302+
foreach (var component in testComponents)
303+
{
304+
Assert.False(AzureSession.Instance.TryGetComponent(component, out string _));
305+
}
306+
}
307+
308+
[Fact]
309+
public void TestEventHandler()
310+
{
311+
bool eventRaised = false;
312+
var listener = new TestSessionListener(() => eventRaised = true);
313+
314+
AzureSession.Instance.RegisterComponent("listener", () => listener);
315+
AzureSession.Instance.RaiseContextClearedEvent();
316+
Assert.True(eventRaised);
317+
318+
eventRaised = false;
319+
AzureSession.Instance.UnregisterComponent<TestSessionListener>("listener");
320+
AzureSession.Instance.RaiseContextClearedEvent();
321+
Assert.False(eventRaised);
322+
}
323+
324+
private class TestSessionListener : IAzureSessionListener
325+
{
326+
private Action _callback;
327+
328+
public TestSessionListener(Action callback)
329+
{
330+
_callback = callback;
331+
}
332+
333+
public void OnEvent(object sender, AzureSessionEventArgs e)
334+
{
335+
if (e.Type == AzureSessionEventType.ContextCleared)
336+
{
337+
_callback();
338+
}
339+
}
340+
}
341+
}
342+
}

src/Authentication.Abstractions/AzureSession.cs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -237,23 +237,31 @@ public void RegisterComponent<T>(string componentName, Func<T> componentInitiali
237237
ChangeRegistry(
238238
() =>
239239
{
240-
var key = new ComponentKey(componentName, typeof(T));
241-
if (!_componentRegistry.ContainsKey(key) || overwrite) // only proceed if key not found or overwrite is true
242-
{
243-
244-
if (overwrite
245-
&& _componentRegistry.TryGetValue(key, out var existed)
246-
&& existed is IAzureSessionListener existedListener)
247-
{
248-
_eventHandler -= existedListener.OnEvent;
249-
}
250-
251-
var component = componentInitializer();
252-
_componentRegistry[key] = component;
253-
if (component is IAzureSessionListener listener)
240+
object oldComponent = null;
241+
bool hasUpdate = true;
242+
var newComponent = _componentRegistry.AddOrUpdate(
243+
new ComponentKey(componentName, typeof(T)),
244+
k => componentInitializer(),
245+
(k, v) =>
254246
{
255-
_eventHandler += listener.OnEvent;
256-
}
247+
if (!overwrite)
248+
{
249+
hasUpdate = false;
250+
return v;
251+
}
252+
else
253+
{
254+
oldComponent = v;
255+
return componentInitializer();
256+
}
257+
});
258+
if (oldComponent is IAzureSessionListener oldListener)
259+
{
260+
_eventHandler -= oldListener.OnEvent;
261+
}
262+
if (hasUpdate && newComponent is IAzureSessionListener listener)
263+
{
264+
_eventHandler += listener.OnEvent;
257265
}
258266
});
259267
}

0 commit comments

Comments
 (0)