using System.Collections.Immutable;
using Connected.Data.Schema;
using Connected.Data.Sharding;
using Connected.Entities.Storage;
using Connected.Middleware;
using Connected.ServiceModel;
using Connected.ServiceModel.Transactions;

namespace Connected.Data.Storage;

/// <summary>
/// Opens the storage connections an entity operation needs and tracks the ones opened in
/// <see cref="StorageConnectionMode.Shared"/> mode, committing or rolling them back when the
/// transaction context reports the corresponding state change.
/// </summary>
internal sealed class ConnectionProvider : IConnectionProvider, IAsyncDisposable, IDisposable
{
    /*
     * Only connections created in Shared mode are tracked here. Set to null once the provider is disposed.
     */
    private List<IStorageConnection>? _connections;

    public ConnectionProvider(IContext context, IMiddlewareService middleware, ITransactionContext transaction)
    {
        Context = context;
        Middleware = middleware;
        TransactionService = transaction;
        _connections = new();
        /*
         * Removed again in Dispose/DisposeAsync so a late state change cannot reach disposed connections.
         */
        TransactionService.StateChanged += OnTransactionStateChanged;
    }

    public IContext Context { get; }
    public IMiddlewareService Middleware { get; }
    private ITransactionContext TransactionService { get; }

    /// <summary>
    /// The shared connections tracked by this provider.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The provider has already been disposed.</exception>
    private List<IStorageConnection> Connections => _connections ?? throw new ObjectDisposedException(nameof(ConnectionProvider));

    /// <summary>
    /// Mode used for connections created from now on. Starts as <see cref="StorageConnectionMode.Shared"/>;
    /// Open switches it to <see cref="StorageConnectionMode.Isolated"/> once the transaction context has completed.
    /// </summary>
    public StorageConnectionMode Mode { get; set; } = StorageConnectionMode.Shared;

    /// <summary>
    /// Returns the connections on which an operation over <typeparamref name="TEntity"/> should run.
    /// </summary>
    /// <typeparam name="TEntity">The entity type the operation works with.</typeparam>
    /// <param name="args">
    /// Operation arguments. When they implement <see cref="ISchemaSynchronizationContext"/>, the connection type and
    /// connection string they carry are used directly; otherwise connections are resolved through middleware.
    /// </param>
    /// <returns>
    /// Either the single connection named by a schema synchronization context, or the default connection followed by
    /// one connection per shard node.
    /// </returns>
    public async Task<ImmutableList<IStorageConnection>> Open<TEntity>(StorageContextArgs args)
    {
        /*
         * Shared connections are committed or rolled back by OnTransactionStateChanged. Once the transaction
         * context has completed, connections opened from now on are created in Isolated mode instead.
         * NOTE(review): this assumes no Committing/Reverting transition follows Completed — confirm.
         */
        if (TransactionService.State == MiddlewareTransactionState.Completed)
            Mode = StorageConnectionMode.Isolated;

        return args is ISchemaSynchronizationContext context
            ? ResolveSingle(context)
            : await ResolveMultiple<TEntity>(args);
    }

    /// <summary>
    /// Resolves the single connection for callers that already specify which connection to use.
    /// </summary>
    /// <remarks>
    /// Used when synchronizing entity schemas, because the synchronization process already knows which
    /// connection type and connection string it operates on.
    /// </remarks>
    /// <param name="args">Synchronization context carrying the connection type and connection string.</param>
    /// <returns>A one-element list with the (possibly reused) connection.</returns>
    private ImmutableList<IStorageConnection> ResolveSingle(ISchemaSynchronizationContext args)
    {
        return ImmutableList.Create(EnsureConnection(args.ConnectionType, args.ConnectionString));
    }

    /// <summary>
    /// Resolves the default connection for <typeparamref name="TEntity"/> plus one connection per shard node,
    /// when a sharding middleware supports the entity.
    /// </summary>
    /// <exception cref="ArgumentException">A shard node uses a different connection type than the default connection.</exception>
    private async Task<ImmutableList<IStorageConnection>> ResolveMultiple<TEntity>(StorageContextArgs args)
    {
        var connectionMiddleware = await ResolveConnectionMiddleware<TEntity>();
        var connectionType = connectionMiddleware.ConnectionType;
        /*
         * Check if sharding is supported on the entity.
         */
        var shardingMiddleware = await ResolveShardingMiddleware<TEntity>();

        var result = new List<IStorageConnection>
        {
            /*
             * Default connection is always used regardless of sharding support.
             */
            EnsureConnection(connectionType, connectionMiddleware.DefaultConnectionString)
        };

        if (shardingMiddleware is not null)
        {
            foreach (var node in await shardingMiddleware.ProvideNodes(args.Operation))
            {
                /*
                 * Sharding is only supported on connections of the same type. Because the names must match,
                 * the already-resolved connection type is reused. Resolving the name with Type.GetType would fail
                 * for types outside this assembly and CoreLib unless the name is assembly-qualified.
                 */
                if (!IsSameConnectionType(node.ConnectionType, connectionType))
                    throw new ArgumentException($"Sharding connection types mismatch ({node.ConnectionType}, expected {connectionType.FullName})");

                result.Add(EnsureConnection(connectionType, node.ConnectionString));
            }
        }

        return result.ToImmutableList();
    }

    /// <summary>
    /// Returns whether <paramref name="typeName"/> names <paramref name="connectionType"/>, either by its full
    /// name or by its assembly-qualified name (ordinal comparison).
    /// </summary>
    private static bool IsSameConnectionType(string? typeName, Type connectionType)
    {
        return string.Equals(typeName, connectionType.FullName, StringComparison.Ordinal)
            || string.Equals(typeName, connectionType.AssemblyQualifiedName, StringComparison.Ordinal);
    }

    /// <summary>
    /// Returns a connection of the given type and connection string. In Shared mode an already-tracked connection
    /// with the same runtime type and connection string (compared case-insensitively) is reused.
    /// </summary>
    private IStorageConnection EnsureConnection(Type connectionType, string connectionString)
    {
        if (Mode == StorageConnectionMode.Shared
            && Connections.FirstOrDefault(f => f.GetType() == connectionType
                && string.Equals(f.ConnectionString, connectionString, StringComparison.OrdinalIgnoreCase)) is IStorageConnection existing)
        {
            return existing;
        }

        return CreateConnection(connectionType, connectionString, Mode);
    }

    /// <summary>
    /// Creates a connection through the context's service provider and initializes it.
    /// </summary>
    /// <exception cref="NullReferenceException">The context did not provide an <see cref="IStorageConnection"/> for <paramref name="connectionType"/>.</exception>
    private IStorageConnection CreateConnection(Type connectionType, string connectionString, StorageConnectionMode behavior)
    {
        // NOTE(review): InvalidOperationException would be the conventional type here; NullReferenceException is kept for compatibility.
        if (Context.GetService(connectionType) is not IStorageConnection newConnection)
            throw new NullReferenceException(connectionType.Name);

        newConnection.Initialize(new StorageConnectionArgs(connectionString, behavior));
        /*
         * Only shared connections are tracked. Isolated connections are neither committed, rolled back
         * nor disposed by this provider.
         */
        if (behavior == StorageConnectionMode.Shared)
            Connections.Add(newConnection);

        return newConnection;
    }

    /// <summary>
    /// Returns the first connection middleware that reports support for <typeparamref name="TEntity"/>.
    /// </summary>
    /// <exception cref="NullReferenceException">No connection middleware supports the entity.</exception>
    private async Task<IConnectionMiddleware> ResolveConnectionMiddleware<TEntity>()
    {
        // NOTE(review): interface name inferred from the members used (ConnectionType, DefaultConnectionString,
        // IsEntitySupported) — confirm against the project.
        foreach (var m in await Middleware.Query<IConnectionMiddleware>())
        {
            if (await m.IsEntitySupported(typeof(TEntity)))
                return m;
        }

        throw new NullReferenceException(nameof(ResolveConnectionMiddleware));
    }

    /// <summary>
    /// Returns the first sharding middleware that supports <typeparamref name="TEntity"/>, or null if none does.
    /// </summary>
    private async Task<IShardingMiddleware?> ResolveShardingMiddleware<TEntity>()
    {
        foreach (var m in await Middleware.Query<IShardingMiddleware>())
        {
            if (m.SupportsEntity(typeof(TEntity)))
                return m;
        }

        return null;
    }

    /// <summary>
    /// Commits the shared connections when the transaction context enters Committing, and rolls them back
    /// when it enters Reverting.
    /// </summary>
    /// <remarks>
    /// This is async void because it is an event handler. An exception from Commit or Rollback is therefore not
    /// returned to the code that raised StateChanged.
    /// </remarks>
    private async void OnTransactionStateChanged(object? sender, EventArgs e)
    {
        if (TransactionService.State == MiddlewareTransactionState.Committing)
            await Commit();
        else if (TransactionService.State == MiddlewareTransactionState.Reverting)
            await Rollback();
    }

    /// <summary>
    /// Commits every tracked shared connection in order. Does nothing once the provider has been disposed.
    /// </summary>
    private async Task Commit()
    {
        if (_connections is not { } connections)
            return;

        foreach (var connection in connections)
            await connection.Commit();
    }

    /// <summary>
    /// Rolls back every tracked shared connection in order. Does nothing once the provider has been disposed.
    /// </summary>
    private async Task Rollback()
    {
        if (_connections is not { } connections)
            return;

        foreach (var connection in connections)
            await connection.Rollback();
    }

    /// <summary>
    /// Stops listening to transaction state changes, then asynchronously disposes the tracked shared connections.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        TransactionService.StateChanged -= OnTransactionStateChanged;

        if (_connections is not null)
        {
            foreach (var connection in _connections)
                await connection.DisposeAsync().ConfigureAwait(false);

            _connections = null;
        }

        Dispose(false);
        GC.SuppressFinalize(this);
    }

    private void Dispose(bool disposing)
    {
        if (disposing)
        {
            TransactionService.StateChanged -= OnTransactionStateChanged;

            if (_connections is not null)
            {
                foreach (var connection in _connections)
                    connection.Dispose();

                _connections = null;
            }
        }
    }

    /// <summary>
    /// Stops listening to transaction state changes and synchronously disposes the tracked shared connections.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }
}