Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ private sealed class PredictionRequestBody
{
public sealed class RequestContext
{
public string CorrelationId { get; set; } = Guid.Empty.ToString();
public string SessionId { get; set; } = Guid.Empty.ToString();
public string SubscriptionId { get; set; } = Guid.Empty.ToString();
public Version VersionNumber{ get; set; } = new Version(0, 0);
}
Expand All @@ -57,6 +55,11 @@ private sealed class CommandRequestContext
public Version VersionNumber{ get; set; } = new Version(0, 0);
}

/// <summary>
/// The name of the header value that contains the platform correlation id.
/// </summary>
private const string CorrelationIdHeader = "Sml-CorrelationId";

private const string ThrottleByIdHeader = "X-UserId";
private readonly HttpClient _client;
private readonly string _commandsEndpoint;
Expand Down Expand Up @@ -285,12 +288,10 @@ public virtual void RequestPredictions(IEnumerable<string> commands)
Task.Run(async () => {
try
{
AzPredictorService.ReplaceThrottleUserIdToHeader(_client?.DefaultRequestHeaders, _azContext.UserId);
AzPredictorService.SetHttpRequestHeader(_client?.DefaultRequestHeaders, _azContext.UserId, _telemetryClient.CorrelationId);

var requestContext = new PredictionRequestBody.RequestContext()
{
SessionId = _telemetryClient.SessionId,
CorrelationId = _telemetryClient.CorrelationId,
VersionNumber = this._azContext.AzVersion
};

Expand Down Expand Up @@ -358,7 +359,7 @@ protected virtual void RequestAllPredictiveCommands()

try
{
_client.DefaultRequestHeaders?.Add(AzPredictorService.ThrottleByIdHeader, _azContext.UserId);
AzPredictorService.SetHttpRequestHeader(_client.DefaultRequestHeaders, _azContext.UserId, _telemetryClient.CorrelationId);

var httpResponseMessage = await _client.GetAsync(_commandsEndpoint);

Expand Down Expand Up @@ -427,21 +428,27 @@ private static string GetCommandName(string commandLine)
return commandLine.Split(AzPredictorConstants.CommandParameterSeperator).First();
}

private static void ReplaceThrottleUserIdToHeader(HttpRequestHeaders header, string value)
private static void SetHttpRequestHeader(HttpRequestHeaders header, string idToThrottle, string correlationId)
{
if (header != null)
{
lock (header)
{
header.Remove(AzPredictorService.ThrottleByIdHeader);

if (!string.IsNullOrWhiteSpace(value))
if (!string.IsNullOrWhiteSpace(idToThrottle))
{
header.Add(AzPredictorService.ThrottleByIdHeader, idToThrottle);
}

header.Remove(AzPredictorService.CorrelationIdHeader);

if (!string.IsNullOrWhiteSpace(correlationId))
{
header.Add(AzPredictorService.ThrottleByIdHeader, value);
header.Add(AzPredictorService.CorrelationIdHeader, correlationId);
}
}
}

}
}
}