Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Net;

namespace Microsoft.Extensions.ServiceDiscovery.Dns;

/// <summary>
/// Provides configuration options for DNS resolution, including server endpoints, retry attempts, and timeout settings.
/// </summary>
public class DnsResolverOptions
{
/// <summary>
/// Gets or sets the collection of server endpoints used for network connections.
/// </summary>
public IList<IPEndPoint> Servers { get; set; } = new List<IPEndPoint>();

/// <summary>
/// Gets or sets the maximum number of attempts per server.
/// </summary>
public int MaxAttempts { get; set; } = 2;

/// <summary>
/// Gets or sets the maximum duration per attempt to wait before timing out.
/// </summary>
/// <remarks>
/// The maximum time for resolving a query is <see cref="MaxAttempts"/> * <see cref="Servers"/> count * <see cref="Timeout"/>.
/// </remarks>
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(3);

// override for testing purposes
internal Func<Memory<byte>, int, int>? _transportOverride;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using Microsoft.Extensions.Options;

namespace Microsoft.Extensions.ServiceDiscovery.Dns;

internal sealed class DnsResolverOptionsValidator : IValidateOptions<DnsResolverOptions>
{
// CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds.
private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue);

public ValidateOptionsResult Validate(string? name, DnsResolverOptions options)
{
if (options.Servers is null)
{
return ValidateOptionsResult.Fail($"{nameof(options.Servers)} must not be null.");
}

if (options.MaxAttempts < 1)
{
return ValidateOptionsResult.Fail($"{nameof(options.MaxAttempts)} must be one or greater.");
}

if (options.Timeout != Timeout.InfiniteTimeSpan)
{
if (options.Timeout <= TimeSpan.Zero)
{
return ValidateOptionsResult.Fail($"{nameof(options.Timeout)} must not be negative or zero.");
}

if (options.Timeout > s_maxTimeout)
{
return ValidateOptionsResult.Fail($"{nameof(options.Timeout)} must not be greater than {s_maxTimeout.TotalMilliseconds} milliseconds.");
}
}

return ValidateOptionsResult.Success;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ internal sealed partial class DnsSrvServiceEndpointProviderFactory(
/// <inheritdoc/>
public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] out IServiceEndpointProvider? provider)
{
var optionsValue = options.CurrentValue;

// If a default namespace is not specified, then this provider will attempt to infer the namespace from the service name, but only when running inside Kubernetes.
// Kubernetes DNS spec: https://github.com/kubernetes/dns/blob/master/docs/specification.md
// SRV records are available for headless services with named ports.
Expand All @@ -30,19 +32,26 @@ public bool TryCreateProvider(ServiceEndpointQuery query, [NotNullWhen(true)] ou
// Otherwise, the namespace can be read from /var/run/secrets/kubernetes.io/serviceaccount/namespace and combined with an assumed suffix of "svc.cluster.local".
// The protocol is assumed to be "tcp".
// The portName is the name of the port in the service definition. If the serviceName parses as a URI, we use the scheme as the port name, otherwise "default".
if (string.IsNullOrWhiteSpace(_querySuffix))
if (optionsValue.ServiceDomainNameCallback == null && string.IsNullOrWhiteSpace(_querySuffix))
{
DnsServiceEndpointProviderBase.Log.NoDnsSuffixFound(logger, query.ToString()!);
provider = default;
return false;
}

var portName = query.EndpointName ?? "default";
var srvQuery = $"_{portName}._tcp.{query.ServiceName}.{_querySuffix}";
var srvQuery = optionsValue.ServiceDomainNameCallback != null
? optionsValue.ServiceDomainNameCallback(query)
: DefaultServiceDomainNameCallback(query, optionsValue);
provider = new DnsSrvServiceEndpointProvider(query, srvQuery, hostName: query.ServiceName, options, logger, resolver, timeProvider);
return true;
}

private static string DefaultServiceDomainNameCallback(ServiceEndpointQuery query, DnsSrvServiceEndpointProviderOptions options)
{
var portName = query.EndpointName ?? "default";
return $"_{portName}._tcp.{query.ServiceName}.{options.QuerySuffix}";
}

private static string? GetKubernetesHostDomain()
{
// Check that we are running in Kubernetes first.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public class DnsSrvServiceEndpointProviderOptions
/// </remarks>
public string? QuerySuffix { get; set; }

/// <summary>
/// Gets or sets a delegate that generates a DNS SRV query from a specified <see cref="ServiceEndpointQuery"/> instance.
/// </summary>
public Func<ServiceEndpointQuery, string>? ServiceDomainNameCallback { get; set; }

/// <summary>
/// Gets or sets a delegate used to determine whether to apply host name metadata to each resolved endpoint. Defaults to <c>false</c>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Security.Cryptography;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;

namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;

Expand All @@ -19,43 +20,40 @@ internal sealed partial class DnsResolver : IDnsResolver, IDisposable
private const int IPv4Length = 4;
private const int IPv6Length = 16;

// CancellationTokenSource.CancelAfter has a maximum timeout of Int32.MaxValue milliseconds.
private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue);

private bool _disposed;
private readonly ResolverOptions _options;
private readonly DnsResolverOptions _options;
private readonly CancellationTokenSource _pendingRequestsCts = new();
private readonly TimeProvider _timeProvider;
private readonly ILogger<DnsResolver> _logger;

public DnsResolver(TimeProvider timeProvider, ILogger<DnsResolver> logger) : this(timeProvider, logger, OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() ? ResolvConf.GetOptions() : NetworkInfo.GetOptions())
{
}

internal DnsResolver(TimeProvider timeProvider, ILogger<DnsResolver> logger, ResolverOptions options)
public DnsResolver(TimeProvider timeProvider, ILogger<DnsResolver> logger, IOptions<DnsResolverOptions> options)
{
_timeProvider = timeProvider;
_logger = logger;
_options = options;
Debug.Assert(_options.Servers.Count > 0);
_options = options.Value;

if (options.Timeout != Timeout.InfiniteTimeSpan)
if (_options.Servers.Count == 0)
{
ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(options.Timeout, TimeSpan.Zero);
ArgumentOutOfRangeException.ThrowIfGreaterThan(options.Timeout, s_maxTimeout);
}
}

internal DnsResolver(ResolverOptions options) : this(TimeProvider.System, NullLogger<DnsResolver>.Instance, options)
{
}
foreach (var server in OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()
? ResolvConf.GetServers()
: NetworkInfo.GetServers())
{
_options.Servers.Add(server);
}

internal DnsResolver(IEnumerable<IPEndPoint> servers) : this(new ResolverOptions(servers.ToArray()))
{
if (_options.Servers.Count == 0)
{
throw new ArgumentException("At least one DNS server is required.", nameof(options));
}
}
}

internal DnsResolver(IPEndPoint server) : this(new ResolverOptions(server))
// This constructor is for unit testing only. Does not auto-add system DNS servers.
internal DnsResolver(DnsResolverOptions options)
{
_timeProvider = TimeProvider.System;
_logger = NullLogger<DnsResolver>.Instance;
_options = options;
}

public ValueTask<ServiceResult[]> ResolveServiceAsync(string name, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -365,7 +363,7 @@ internal struct SendQueryResult
{
IPEndPoint serverEndPoint = _options.Servers[index];

for (int attempt = 1; attempt <= _options.Attempts; attempt++)
for (int attempt = 1; attempt <= _options.MaxAttempts; attempt++)
{
DnsResponse response = default;
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;

internal static class NetworkInfo
{
// basic option to get DNS serves via NetworkInfo. We may get it directly later via proper APIs.
public static ResolverOptions GetOptions()
// basic option to get DNS servers via NetworkInfo. We may get it directly later via proper APIs.
public static IList<IPEndPoint> GetServers()
{
List<IPEndPoint> servers = new List<IPEndPoint>();

Expand All @@ -31,6 +31,6 @@ public static ResolverOptions GetOptions()
}
}

return new ResolverOptions(servers);
return servers;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ internal static class ResolvConf
{
[SupportedOSPlatform("linux")]
[SupportedOSPlatform("osx")]
public static ResolverOptions GetOptions()
public static IList<IPEndPoint> GetServers()
{
return GetOptions(new StreamReader("/etc/resolv.conf"));
return GetServers(new StreamReader("/etc/resolv.conf"));
}

public static ResolverOptions GetOptions(TextReader reader)
public static IList<IPEndPoint> GetServers(TextReader reader)
{
List<IPEndPoint> serverList = new();

Expand All @@ -40,9 +40,9 @@ public static ResolverOptions GetOptions(TextReader reader)
if (serverList.Count == 0)
{
// If no nameservers are configured, fall back to the default behavior of using the system resolver configuration.
return NetworkInfo.GetOptions();
return NetworkInfo.GetServers();
}

return new ResolverOptions(serverList);
return serverList;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.ServiceDiscovery;
using Microsoft.Extensions.ServiceDiscovery.Dns;
using Microsoft.Extensions.ServiceDiscovery.Dns.Resolver;
Expand Down Expand Up @@ -59,24 +60,10 @@ public static IServiceCollection AddDnsSrvServiceEndpointProvider(this IServiceC

services.AddSingleton<IServiceEndpointProviderFactory, DnsSrvServiceEndpointProviderFactory>();
var options = services.AddOptions<DnsSrvServiceEndpointProviderOptions>();
options.Configure(o => configureOptions?.Invoke(o));
options.Configure(configureOptions);

return services;

static bool GetDnsClientFallbackFlag()
{
if (AppContext.TryGetSwitch("Microsoft.Extensions.ServiceDiscovery.Dns.UseDnsClientFallback", out var value))
{
return value;
}

var envVar = Environment.GetEnvironmentVariable("MICROSOFT_EXTENSIONS_SERVICE_DISCOVERY_DNS_USE_DNSCLIENT_FALLBACK");
if (envVar is not null && (envVar.Equals("true", StringComparison.OrdinalIgnoreCase) || envVar.Equals("1")))
{
return true;
}

return false;
}
}

/// <summary>
Expand Down Expand Up @@ -109,9 +96,55 @@ public static IServiceCollection AddDnsServiceEndpointProvider(this IServiceColl
ArgumentNullException.ThrowIfNull(configureOptions);

services.AddServiceDiscoveryCore();

if (!GetDnsClientFallbackFlag())
{
services.TryAddSingleton<IDnsResolver, DnsResolver>();
}
else
{
services.TryAddSingleton<IDnsResolver, FallbackDnsResolver>();
services.TryAddSingleton<DnsClient.LookupClient>();
}

services.AddSingleton<IServiceEndpointProviderFactory, DnsServiceEndpointProviderFactory>();
var options = services.AddOptions<DnsServiceEndpointProviderOptions>();
options.Configure(o => configureOptions?.Invoke(o));
options.Configure(configureOptions);

return services;
}

/// <summary>
/// Configures the DNS resolver used for service discovery.
/// </summary>
/// <param name="services">The service collection.</param>
/// <param name="configureOptions">The DNS resolver options.</param>
/// <returns>The provided <see cref="IServiceCollection"/>.</returns>
public static IServiceCollection ConfigureDnsResolver(this IServiceCollection services, Action<DnsResolverOptions> configureOptions)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(configureOptions);

var options = services.AddOptions<DnsResolverOptions>();
options.Configure(configureOptions);
services.AddTransient<IValidateOptions<DnsResolverOptions>, DnsResolverOptionsValidator>();
return services;
}

private static bool GetDnsClientFallbackFlag()
{
if (AppContext.TryGetSwitch("Microsoft.Extensions.ServiceDiscovery.Dns.UseDnsClientFallback", out var value))
{
return value;
}

var envVar = Environment.GetEnvironmentVariable("MICROSOFT_EXTENSIONS_SERVICE_DISCOVERY_DNS_USE_DNSCLIENT_FALLBACK");
if (envVar is not null && (envVar.Equals("true", StringComparison.OrdinalIgnoreCase) || envVar.Equals("1")))
{
return true;
}

return false;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ public void FuzzTarget(ReadOnlySpan<byte> data)
if (_resolver == null)
{
_buffer = new byte[4096];
_resolver = new DnsResolver(new ResolverOptions(new IPEndPoint(IPAddress.Loopback, 53))
_resolver = new DnsResolver(new DnsResolverOptions
{
Servers = [new IPEndPoint(IPAddress.Loopback, 53)],
Timeout = TimeSpan.FromSeconds(5),
Attempts = 1,
MaxAttempts = 1,
_transportOverride = (buffer, length) =>
{
// the first two bytes are the random transaction ID, so we keep that
Expand All @@ -41,4 +42,4 @@ public void FuzzTarget(ReadOnlySpan<byte> data)
Debug.Assert(task.IsCompleted, "Task should be completed synchronously");
task.GetAwaiter().GetResult();
}
}
}
Loading
Loading