Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions PSql/Resolve-SqlClient.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ else
{
Add-Type -Path (Join-Path $PSScriptRoot runtimes/unix/lib/netcoreapp3.1/Microsoft.Data.SqlClient.dll)
}

# Required for Azure Active Directory authentication modes
Add-Type -Path (Join-Path $PSScriptRoot Microsoft.Identity.Client.dll)
37 changes: 30 additions & 7 deletions PSql/_Commands/NewSqlContextCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,35 @@ private const string

// -ResourceGroupName
[Alias("ResourceGroup")]
[Parameter(ParameterSetName = AzureName, Position = 1, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, ValueFromPipelineByPropertyName = true)]
[ValidateNotNullOrEmpty]
public string ResourceGroupName { get; set; }

// -ServerName
[Alias("Server")]
[Parameter(ParameterSetName = GenericName, Position = 0, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 2, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 1, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[ValidateNotNullOrEmpty]
public string ServerName { get; set; }

// -DatabaseName
[Alias("Database")]
[Parameter(ParameterSetName = GenericName, Position = 1, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 3, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 2, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[ValidateNotNullOrEmpty]
public string DatabaseName { get; set; }

// -Credential
[Parameter(ParameterSetName = GenericName, Position = 2, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 4, Mandatory = true, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = GenericName, Position = 2, ValueFromPipelineByPropertyName = true)]
[Parameter(ParameterSetName = AzureName, Position = 3, ValueFromPipelineByPropertyName = true)]
[Credential]
public PSCredential Credential { get; set; } = PSCredential.Empty;

// -AuthenticationMode
[Alias("Auth")]
[Parameter(ParameterSetName = AzureName, ValueFromPipelineByPropertyName = true)]
public AzureAuthenticationMode AuthenticationMode { get; set; }

// -EncryptionMode
[Alias("Encryption")]
[Parameter(ParameterSetName = GenericName, ValueFromPipelineByPropertyName = true)]
Expand Down Expand Up @@ -70,11 +75,25 @@ private const string
[ValidateRange("0:00:00", "24855.03:14:07")]
public TimeSpan? ConnectTimeout { get; set; }

// -ExposeCredentialInConnectionString
[Parameter(ValueFromPipelineByPropertyName = true)]
public SwitchParameter ExposeCredentialInConnectionString { get; set; }

// -Pooling
[Parameter(ValueFromPipelineByPropertyName = true)]
public SwitchParameter Pooling { get; set; }

// -MultipleActiveResultSets
[Alias("Mars")]
[Parameter(ValueFromPipelineByPropertyName = true)]
public SwitchParameter MultipleActiveResultSets { get; set; } = true;

protected override void ProcessRecord()
{
var context = Azure.IsPresent
? new AzureSqlContext { ResourceGroupName = ResourceGroupName }
: new SqlContext { EncryptionMode = EncryptionMode };
? new AzureSqlContext { ResourceGroupName = ResourceGroupName ,
AuthenticationMode = AuthenticationMode }
: new SqlContext { EncryptionMode = EncryptionMode };

var credential = Credential.IsNullOrEmpty()
? null
Expand All @@ -88,6 +107,10 @@ protected override void ProcessRecord()
context.ApplicationName = ApplicationName;
context.ApplicationIntent = ReadOnlyIntent ? ReadOnly : ReadWrite;

context.ExposeCredentialInConnectionString = ExposeCredentialInConnectionString;
context.EnableConnectionPooling = Pooling;
context.EnableMultipleActiveResultSets = MultipleActiveResultSets;

WriteObject(context);
}
}
Expand Down
55 changes: 55 additions & 0 deletions PSql/_Data/AzureAuthenticationMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Sam = Microsoft.Data.SqlClient.SqlAuthenticationMethod;

namespace PSql
{
/// <summary>
/// Modes for authentiating connections to Azure SQL Database and
/// compatible databases.
/// </summary>
public enum AzureAuthenticationMode
{
/// <summary>
/// Default authentication mode. The actual authentication mode
/// depends on the value of the <see cref="SqlContext.Credential"/>
/// property. If the property is non-<c>null</c>, this mode selects
/// SQL authentication using the credential. If the property is
/// <c>null</c>, this mode selects Azure AD integrated
/// authentication.
/// </summary>
Default = Sam.NotSpecified,

/// <summary>
/// SQL authentication mode. The <see cref="SqlContext.Credential"/>
/// property should contain the name and password stored for a server
/// login or contained database user.
/// </summary>
SqlPassword = Sam.SqlPassword,

/// <summary>
/// Azure Active Directory password authentication mode. The
/// <see cref="SqlContext.Credential"/> property should contain the
/// name and password of an Azure AD principal.
/// </summary>
AadPassword = Sam.ActiveDirectoryPassword,

/// <summary>
/// Azure Active Directory integrated authentication mode. The
/// identity of the process should be an Azure AD principal.
/// </summary>
AadIntegrated = Sam.ActiveDirectoryIntegrated,

/// <summary>
/// Azure Active Directory interactive authentication mode, also
/// known as Universal Authentication with MFA. Authentication uses
/// an interactive flow and supports multiple factors.
/// </summary>
AadInteractive = Sam.ActiveDirectoryInteractive,

/// <summary>
/// Azure Active Directory service principal authentication mode.
/// The <see cref="SqlContext.Credential"/> property contains the
/// client ID and secret of an Azure AD service principal.
/// </summary>
AadServicePrincipal = Sam.ActiveDirectoryServicePrincipal
}
}
49 changes: 46 additions & 3 deletions PSql/_Data/AzureSqlContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ public class AzureSqlContext : SqlContext
{
public AzureSqlContext()
{
// Encryption is required for connections to Azure SQL Database
EncryptionMode = EncryptionMode.Full;
}

public string ResourceGroupName { get; set; }

public string ServerFullName { get; private set; }

public AzureAuthenticationMode AuthenticationMode { get; set; }

protected override void BuildConnectionString(SqlConnectionStringBuilder builder)
{
if (Credential.IsNullOrEmpty())
throw new NotSupportedException("A credential is required when connecting to Azure SQL Database.");

base.BuildConnectionString(builder);

builder.DataSource = ServerFullName ?? ResolveServerFullName();
Expand All @@ -34,13 +34,56 @@ protected override void BuildConnectionString(SqlConnectionStringBuilder builder
builder.InitialCatalog = MasterDatabaseName;
}

protected override void ConfigureAuthentication(SqlConnectionStringBuilder builder)
{
var auth = (SqlAuthenticationMethod) AuthenticationMode;

switch (auth)
{
case SqlAuthenticationMethod.NotSpecified when Credential != null:
auth = SqlAuthenticationMethod.SqlPassword;
break;

case SqlAuthenticationMethod.NotSpecified:
auth = SqlAuthenticationMethod.ActiveDirectoryIntegrated;
break;

case SqlAuthenticationMethod.SqlPassword:
case SqlAuthenticationMethod.ActiveDirectoryPassword:
case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal:
if (Credential.IsNullOrEmpty())
throw new NotSupportedException("A credential is required when connecting to Azure SQL Database.");
break;
}

builder.Authentication = auth;
}

protected override void ConfigureEncryption(SqlConnectionStringBuilder builder)
{
// Encryption is required for connections to Azure SQL Database
builder.Encrypt = true;

// Always verify server identity
// builder.TrustServerCertificate defaults to false
}

private string ResolveServerFullName()
{
// Check if ServerName should be used as ServerFullName verbatim

if (string.IsNullOrEmpty(ServerName))
throw new InvalidOperationException("ServerName is required.");

var shouldUseServerNameVerbatim
= ServerName.Contains('.', StringComparison.Ordinal)
|| string.IsNullOrEmpty(ResourceGroupName);

if (shouldUseServerNameVerbatim)
return ServerName;

// Resolve ServerFullName using Az cmdlets

var value = ScriptBlock
.Create("param ($x) Get-AzSqlServer @x -ea Stop")
.Invoke(new Dictionary<string, object>
Expand Down
39 changes: 27 additions & 12 deletions PSql/_Data/SqlContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ protected const string

public ApplicationIntent ApplicationIntent { get; set; }

public bool ExposeCredentialInConnectionString { get; set; }

public bool EnableConnectionPooling { get; set; }

public bool EnableMultipleActiveResultSets { get; set; }

internal SqlConnection CreateConnection(string databaseName)
{
var builder = new SqlConnectionStringBuilder();
Expand Down Expand Up @@ -62,14 +68,9 @@ protected virtual void BuildConnectionString(SqlConnectionStringBuilder builder)
//else
// server determines database

// Authentication
if (Credential.IsNullOrEmpty())
builder.IntegratedSecurity = true;
//else
// will provide credential as a SqlCredential object

// Encryption & Server Identity Check
ConfigureEncryption(builder);
// Security
ConfigureAuthentication (builder);
ConfigureEncryption (builder);

// Timeout
if (ConnectTimeout.HasValue)
Expand All @@ -88,7 +89,18 @@ protected virtual void BuildConnectionString(SqlConnectionStringBuilder builder)
builder.ApplicationIntent = ApplicationIntent;

// Other
builder.Pooling = false;
builder.PersistSecurityInfo = ExposeCredentialInConnectionString;
builder.MultipleActiveResultSets = EnableMultipleActiveResultSets;
builder.Pooling = EnableConnectionPooling;
}

protected virtual void ConfigureAuthentication(SqlConnectionStringBuilder builder)
{
// Authentication
if (Credential.IsNullOrEmpty())
builder.IntegratedSecurity = true;
//else
// will provide credential as a SqlCredential object
}

protected virtual void ConfigureEncryption(SqlConnectionStringBuilder builder)
Expand All @@ -105,11 +117,14 @@ protected virtual void ConfigureEncryption(SqlConnectionStringBuilder builder)

private (bool, bool) TranslateEncryptionMode(EncryptionMode mode)
{
// tuple: (useEncryption, useServerIdentityCheck)

switch (mode)
{
case EncryptionMode.None: return (false, false);
case EncryptionMode.Unverified: return (true, false);
case EncryptionMode.Full: return (true, true );
// ( ENCRYPT, VERIFY )
case EncryptionMode.None: return ( false, false );
case EncryptionMode.Unverified: return ( true, false );
case EncryptionMode.Full: return ( true, true );
case EncryptionMode.Default:
default:
var isRemote = !GetIsLocal();
Expand Down
Loading