Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using a managed identity to connect to Azure Database for PostgreSQL #294

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c92b259
update nuget packages and initial code implementation
MattMcL4475 Jul 18, 2023
3559b8b
refactor
MattMcL4475 Jul 18, 2023
ff1736c
update logic
MattMcL4475 Jul 18, 2023
820c2e0
refactor
MattMcL4475 Jul 18, 2023
ecc5856
fix connection string logic
MattMcL4475 Jul 18, 2023
7795580
rewrite for clarity
MattMcL4475 Jul 18, 2023
01dd40a
refactor
MattMcL4475 Jul 18, 2023
cf72264
add new setting
MattMcL4475 Jul 18, 2023
e4ccb52
fix semicolon
MattMcL4475 Jul 18, 2023
9335a3e
add note
MattMcL4475 Jul 18, 2023
4954869
add additional argument exception check
MattMcL4475 Jul 18, 2023
df02e60
update casing
MattMcL4475 Jul 18, 2023
6c827f1
add note about connection string check
MattMcL4475 Jul 18, 2023
dd03d82
Merge remote-tracking branch 'origin/main' into feature/AddManagedIde…
MattMcL4475 Jul 18, 2023
1022f25
fix formatting
MattMcL4475 Jul 18, 2023
e7722b0
hoist maxbatchsize
MattMcL4475 Jul 18, 2023
7eba550
minor formatting
MattMcL4475 Jul 18, 2023
d6d1861
Merge remote-tracking branch 'origin/main' into feature/AddManagedIde…
MattMcL4475 Jul 27, 2023
1402975
Merge remote-tracking branch 'origin/main' into feature/AddManagedIde…
MattMcL4475 Aug 15, 2023
9bf1c3e
add Microsoft.EntityFrameworkCore.Relational and update packages
MattMcL4475 Aug 22, 2023
60cbf7c
Merge branch 'main' into feature/AddManagedIdentitySupportToPostgres
MattMcL4475 Oct 3, 2023
866f458
Merge remote-tracking branch 'origin/main' into feature/AddManagedIde…
MattMcL4475 Oct 9, 2023
6441716
update implementation and refactor dbcontext
MattMcL4475 Oct 10, 2023
4025abe
Merge remote-tracking branch 'origin/main' into feature/AddManagedIde…
MattMcL4475 Oct 10, 2023
f956900
update nuget
MattMcL4475 Oct 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.10.1" />
<PackageReference Include="Azure.Identity" Version="1.10.2" />
<PackageReference Include="Azure.ResourceManager.Batch" Version="1.1.1" />
<PackageReference Include="Azure.ResourceManager.Compute" Version="1.1.0" />
<PackageReference Include="Azure.Security.KeyVault.Secrets" Version="4.4.0" />
Expand Down
2 changes: 1 addition & 1 deletion src/Tes.ApiClients/Tes.ApiClients.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

<ItemGroup>
<PackageReference Include="Azure.Core" Version="1.35.0" />
<PackageReference Include="Azure.Identity" Version="1.10.1" />
<PackageReference Include="Azure.Identity" Version="1.10.2" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
Expand Down
2 changes: 1 addition & 1 deletion src/Tes.Runner.Test/Tes.Runner.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

<ItemGroup>
<PackageReference Include="Azure.Storage.Blobs" Version="12.16.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.5" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.12" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.5.0" />
<PackageReference Include="Moq" Version="4.18.4" />
<PackageReference Include="MSTest.TestAdapter" Version="3.0.2" />
Expand Down
2 changes: 1 addition & 1 deletion src/Tes.Runner/Tes.Runner.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.10.1" />
<PackageReference Include="Azure.Identity" Version="1.10.2" />
<PackageReference Include="Azure.Storage.Blobs" Version="12.16.0" />
<PackageReference Include="Docker.DotNet" Version="3.125.14" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="7.0.0" />
Expand Down
1 change: 1 addition & 0 deletions src/Tes/Models/PostgreSqlOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ public static string GetConfigurationSectionName(string serviceName = "Tes")
public string DatabaseName { get; set; } = "tes_db";
public string DatabaseUserLogin { get; set; }
public string DatabaseUserPassword { get; set; }
public bool UseManagedIdentity { get; set; }
}
}
2 changes: 1 addition & 1 deletion src/Tes/Models/TesTaskPostgres.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Tes.Models
/// <summary>
/// Database schema for encapsulating a TesTask as Json for Postgresql.
/// </summary>
[Table(Repository.TesDbContext.TesTasksPostgresTableName)]
[Table("testasks")]
public class TesTaskDatabaseItem
{
[Column("id")]
Expand Down
10 changes: 6 additions & 4 deletions src/Tes/Repository/PostgreSqlCachingRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Polly;

namespace Tes.Repository
{
public abstract class PostgreSqlCachingRepository<T> : IDisposable where T : class
{
private readonly IServiceScopeFactory _scopeFactory = null!;
private readonly TimeSpan _writerWaitTime = TimeSpan.FromMilliseconds(50);
private readonly int _batchSize = 1000;
private static readonly TimeSpan defaultCompletedTaskCacheExpiration = TimeSpan.FromDays(1);
Expand All @@ -30,17 +32,16 @@ public abstract class PostgreSqlCachingRepository<T> : IDisposable where T : cla
private readonly Task _writerWorkerTask;

protected enum WriteAction { Add, Update, Delete }

protected Func<TesDbContext> CreateDbContext { get; init; }
protected readonly ICache<T> _cache;
protected readonly ILogger _logger;

private bool _disposedValue;

protected PostgreSqlCachingRepository(ILogger logger = default, ICache<T> cache = default)
protected PostgreSqlCachingRepository(ILogger logger = default, ICache<T> cache = default, IServiceScopeFactory scopeFactory = default)
{
_logger = logger;
_cache = cache;
_scopeFactory = scopeFactory;

// The only "normal" exit for _writerWorkerTask is "cancelled". Anything else should force the process to exit because it means that this repository will no longer write to the database!
_writerWorkerTask = Task.Run(() => WriterWorkerAsync(_writerWorkerCancellationTokenSource.Token))
Expand Down Expand Up @@ -187,7 +188,8 @@ private async ValueTask WriteItemsAsync(IList<(T DbItem, WriteAction Action, Tas
if (dbItems.Count == 0) { return; }

cancellationToken.ThrowIfCancellationRequested();
using var dbContext = CreateDbContext();
using var scope = _scopeFactory.CreateScope();
using var dbContext = scope.ServiceProvider.GetRequiredService<TesDbContext>();

// Manually set entity state to avoid potential NPG PostgreSql bug
dbContext.ChangeTracker.AutoDetectChangesEnabled = false;
Expand Down
26 changes: 13 additions & 13 deletions src/Tes/Repository/TesDbContext.cs
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using Azure.Core;
using Azure.Identity;
using Microsoft.EntityFrameworkCore;
using Tes.Models;
using Tes.Utilities;

namespace Tes.Repository
{
public class TesDbContext : DbContext
{
public const string TesTasksPostgresTableName = "testasks";
private const int maxBatchSize = 1000;
private readonly PostgresConnectionStringUtility connectionStringUtility = null!;

public TesDbContext()
{
// Default constructor, which is required to run the EF migrations tool,
// "dotnet ef migrations add InitialCreate"
// DI will NOT use this constructor
}

public TesDbContext(string connectionString)
public TesDbContext(PostgresConnectionStringUtility connectionStringUtility)
{
ArgumentException.ThrowIfNullOrEmpty(connectionString, nameof(connectionString));
ConnectionString = connectionString;
this.connectionStringUtility = connectionStringUtility;
}

public string ConnectionString { get; set; }
public DbSet<TesTaskDatabaseItem> TesTasks { get; set; }

protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
if (!optionsBuilder.IsConfigured)
{
// use PostgreSQL
optionsBuilder
.UseNpgsql(ConnectionString, options => options.MaxBatchSize(1000))
.UseLowerCaseNamingConvention();
}
string connectionString = this.connectionStringUtility.GetConnectionString().Result;

optionsBuilder
.UseNpgsql(connectionString, options => options.MaxBatchSize(maxBatchSize))
.UseLowerCaseNamingConvention();
}
}
}
30 changes: 11 additions & 19 deletions src/Tes/Repository/TesTaskPostgreSqlRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace Tes.Repository
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Polly;
using Tes.Models;
using Tes.Utilities;
Expand All @@ -23,34 +23,24 @@ namespace Tes.Repository
/// <typeparam name="TesTask"></typeparam>
public sealed class TesTaskPostgreSqlRepository : PostgreSqlCachingRepository<TesTaskDatabaseItem>, IRepository<TesTask>
{
private readonly IServiceScopeFactory _scopeFactory = null!;

/// <summary>
/// Default constructor that also will create the schema if it does not exist
/// </summary>
/// <param name="options"></param>
/// <param name="logger"></param>
/// <param name="cache"></param>
public TesTaskPostgreSqlRepository(IOptions<PostgreSqlOptions> options, ILogger<TesTaskPostgreSqlRepository> logger, ICache<TesTaskDatabaseItem> cache = null)
public TesTaskPostgreSqlRepository(ILogger<TesTaskPostgreSqlRepository> logger = default, IServiceScopeFactory scopeFactory = default, ICache<TesTaskDatabaseItem> cache = null)
: base(logger, cache)
{
var connectionString = new ConnectionStringUtility().GetPostgresConnectionString(options);
CreateDbContext = () => { return new TesDbContext(connectionString); };
using var dbContext = CreateDbContext();
_scopeFactory = scopeFactory;
using var scope = _scopeFactory.CreateScope();
using var dbContext = scope.ServiceProvider.GetRequiredService<TesDbContext>();
dbContext.Database.MigrateAsync().Wait();
WarmCacheAsync(CancellationToken.None).Wait();
}

/// <summary>
/// Constructor for testing to enable mocking DbContext
/// </summary>
/// <param name="createDbContext">A delegate that creates a TesTaskPostgreSqlRepository context</param>
public TesTaskPostgreSqlRepository(Func<TesDbContext> createDbContext)
: base()
{
CreateDbContext = createDbContext;
using var dbContext = createDbContext();
dbContext.Database.MigrateAsync().Wait();
}

private async Task WarmCacheAsync(CancellationToken cancellationToken)
{
if (_cache is null)
Expand Down Expand Up @@ -224,7 +214,8 @@ private async Task<TesTaskDatabaseItem> GetItemFromCacheOrDatabase(string id, bo

if (!_cache?.TryGetValue(id, out item) ?? true)
{
using var dbContext = CreateDbContext();
using var scope = _scopeFactory.CreateScope();
using var dbContext = scope.ServiceProvider.GetRequiredService<TesDbContext>();

// Search for Id within the JSON
item = await _asyncPolicy.ExecuteAsync(ct => dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id, ct), cancellationToken);
Expand Down Expand Up @@ -252,7 +243,8 @@ private async Task<IEnumerable<TesTask>> InternalGetItemsAsync(Expression<Func<T
//orderBy = pagination is null ? orderBy : q => q.OrderBy(t => t.Json.CreationTime).ThenBy(t => t.Json.Id);
orderBy = pagination is null ? orderBy : q => q.OrderBy(t => t.Json.Id);

using var dbContext = CreateDbContext();
using var scope = _scopeFactory.CreateScope();
using var dbContext = scope.ServiceProvider.GetRequiredService<TesDbContext>();
return (await GetItemsAsync(dbContext.TesTasks, WhereTesTask(predicate), cancellationToken, orderBy, pagination)).Select(item => EnsureActiveItemInCache(item, t => t.Json.Id, t => t.Json.IsActiveState()).Json);
}

Expand Down
10 changes: 6 additions & 4 deletions src/Tes/Tes.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.10.2" />
<PackageReference Include="EFCore.NamingConventions" Version="7.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="7.0.3">
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.12" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="7.0.12">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="7.0.3" />
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="7.0.12" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="7.0.12" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="7.0.0" />
<PackageReference Include="Microsoft.Rest.ClientRuntime" Version="2.3.24" />
<!--Mitigate reported security issues-->
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="7.0.3" />
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="7.0.11" />
<PackageReference Include="Polly" Version="7.2.3" />
<PackageReference Include="Polly.Extensions.Http" Version="3.0.0" />
</ItemGroup>
Expand Down
84 changes: 65 additions & 19 deletions src/Tes/Utilities/PostgresConnectionStringUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,81 @@

using System;
using System.Text;
using Microsoft.Extensions.Options;
using System.Threading.Tasks;
using Azure.Core;
using Tes.Models;

namespace Tes.Utilities
{
public class ConnectionStringUtility
public class PostgresConnectionStringUtility
{
public string GetPostgresConnectionString(IOptions<PostgreSqlOptions> options)
private const string azureDatabaseForPostgresqlScope = "https://ossrdbms-aad.database.windows.net/.default";
private readonly string connectionString = null!;
private readonly TokenCredential tokenCredential = null!;
public bool UseManagedIdentity { get; set; }

public PostgresConnectionStringUtility(PostgreSqlOptions options, TokenCredential tokenCredential)
{
this.tokenCredential = tokenCredential;
connectionString = InternalGetConnectionString(options);
UseManagedIdentity = options.UseManagedIdentity;
}

public async Task<string> GetConnectionString()
{
ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerName, nameof(options.Value.ServerName));
ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerNameSuffix, nameof(options.Value.ServerNameSuffix));
ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerPort, nameof(options.Value.ServerPort));
ArgumentException.ThrowIfNullOrEmpty(options.Value.ServerSslMode, nameof(options.Value.ServerSslMode));
ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseName, nameof(options.Value.DatabaseName));
ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseUserLogin, nameof(options.Value.DatabaseUserLogin));
ArgumentException.ThrowIfNullOrEmpty(options.Value.DatabaseUserPassword, nameof(options.Value.DatabaseUserPassword));

if (options.Value.ServerName.Contains(options.Value.ServerNameSuffix, StringComparison.OrdinalIgnoreCase))
if (UseManagedIdentity)
{
throw new ArgumentException($"'{nameof(options.Value.ServerName)}' should only contain the name of the server like 'myserver' and NOT the full host name like 'myserver{options.Value.ServerNameSuffix}'", nameof(options.Value.ServerName));
// Use AAD managed identity
// https://learn.microsoft.com/en-us/azure/postgresql/single-server/how-to-connect-with-managed-identity
// https://learn.microsoft.com/en-us/azure/postgresql/single-server/concepts-azure-ad-authentication

var accessToken = await tokenCredential.GetTokenAsync(
new TokenRequestContext(scopes: new string[] { azureDatabaseForPostgresqlScope }), System.Threading.CancellationToken.None);

return $"{connectionString}Password={accessToken.Token};";
}

return connectionString;
}

private string InternalGetConnectionString(PostgreSqlOptions options)
{
ArgumentException.ThrowIfNullOrEmpty(options.ServerName, nameof(options.ServerName));
ArgumentException.ThrowIfNullOrEmpty(options.ServerNameSuffix, nameof(options.ServerNameSuffix));
ArgumentException.ThrowIfNullOrEmpty(options.ServerPort, nameof(options.ServerPort));
ArgumentException.ThrowIfNullOrEmpty(options.ServerSslMode, nameof(options.ServerSslMode));
ArgumentException.ThrowIfNullOrEmpty(options.DatabaseName, nameof(options.DatabaseName));
ArgumentException.ThrowIfNullOrEmpty(options.DatabaseUserLogin, nameof(options.DatabaseUserLogin));

if (!options.UseManagedIdentity)
{
// Ensure password is set if NOT using Managed Identity
ArgumentException.ThrowIfNullOrEmpty(options.DatabaseUserPassword, nameof(options.DatabaseUserPassword));
}

if (options.UseManagedIdentity && !string.IsNullOrWhiteSpace(options.DatabaseUserPassword))
{
// throw if password IS set when using Managed Identity
throw new ArgumentException("DatabaseUserPassword shall not be set if UseManagedIdentity is true");
}

if (options.ServerName.Contains(options.ServerNameSuffix, StringComparison.OrdinalIgnoreCase))
{
throw new ArgumentException($"'{nameof(options.ServerName)}' should only contain the name of the server like 'myserver' and NOT the full host name like 'myserver{options.ServerNameSuffix}'", nameof(options.ServerName));
}

var connectionStringBuilder = new StringBuilder();
connectionStringBuilder.Append($"Server={options.Value.ServerName}{options.Value.ServerNameSuffix};");
connectionStringBuilder.Append($"Database={options.Value.DatabaseName};");
connectionStringBuilder.Append($"Port={options.Value.ServerPort};");
connectionStringBuilder.Append($"User Id={options.Value.DatabaseUserLogin};");
connectionStringBuilder.Append($"Password={options.Value.DatabaseUserPassword};");
connectionStringBuilder.Append($"SSL Mode={options.Value.ServerSslMode};");
connectionStringBuilder.Append($"Server={options.ServerName}{options.ServerNameSuffix};");
connectionStringBuilder.Append($"Database={options.DatabaseName};");
connectionStringBuilder.Append($"Port={options.ServerPort};");
connectionStringBuilder.Append($"SSL Mode={options.ServerSslMode};");
connectionStringBuilder.Append($"User Id={options.DatabaseUserLogin};");

if (!options.UseManagedIdentity)
{
connectionStringBuilder.Append($"Password={options.DatabaseUserPassword};");
}

return connectionStringBuilder.ToString();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ await PostgreSqlTestUtility.CreateTestDbAsync(
DatabaseUserPassword = adminPw
};

var optionsMock = new Mock<IOptions<PostgreSqlOptions>>();
optionsMock.Setup(x => x.Value).Returns(options);
var connectionString = new ConnectionStringUtility().GetPostgresConnectionString(optionsMock.Object);
repository = new TesTaskPostgreSqlRepository(() => new TesDbContext(connectionString));
var optionsMock = new Mock<PostgreSqlOptions>();
var connectionString = new PostgresConnectionStringUtility(optionsMock.Object, null);
repository = new TesTaskPostgreSqlRepository();
Console.WriteLine("Creation complete.");
}

Expand Down
6 changes: 5 additions & 1 deletion src/TesApi.Web/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using Tes.ApiClients.Options;
using Tes.Models;
using Tes.Repository;
using Tes.Utilities;
using TesApi.Filters;
using TesApi.Web.Management;
using TesApi.Web.Management.Batch;
Expand Down Expand Up @@ -77,6 +78,10 @@ public void ConfigureServices(IServiceCollection services)
.Configure<MarthaOptions>(configuration.GetSection(MarthaOptions.SectionName))

.AddMemoryCache(o => o.ExpirationScanFrequency = TimeSpan.FromHours(12))

.AddSingleton<TokenCredential, DefaultAzureCredential>()
.AddSingleton<PostgresConnectionStringUtility>()
.AddDbContext<TesDbContext>(ServiceLifetime.Scoped)
.AddSingleton<ICache<TesTaskDatabaseItem>, TesRepositoryCache<TesTaskDatabaseItem>>()
.AddSingleton<TesTaskPostgreSqlRepository>()
.AddSingleton<AzureProxy>()
Expand Down Expand Up @@ -108,7 +113,6 @@ public void ConfigureServices(IServiceCollection services)
.AddSingleton<AzureManagementClientsFactory>()
.AddSingleton<ConfigurationUtils>()
.AddSingleton<IAllowedVmSizesService, AllowedVmSizesService>()
.AddSingleton<TokenCredential>(s => new DefaultAzureCredential())
.AddSingleton<TaskToNodeTaskConverter>()
.AddSingleton<TaskExecutionScriptingManager>()
.AddTransient<BatchNodeScriptBuilder>()
Expand Down
Loading
Loading