diff --git a/Connected.Caching/Annotations/CacheKeyAttribute.cs b/Connected.Caching/Annotations/CacheKeyAttribute.cs new file mode 100644 index 0000000..9cede01 --- /dev/null +++ b/Connected.Caching/Annotations/CacheKeyAttribute.cs @@ -0,0 +1,6 @@ +namespace Connected.Caching.Annotations; + +[AttributeUsage(AttributeTargets.Property)] +public sealed class CacheKeyAttribute : Attribute +{ +} diff --git a/Connected.Caching/Cache.cs b/Connected.Caching/Cache.cs new file mode 100644 index 0000000..dd1764b --- /dev/null +++ b/Connected.Caching/Cache.cs @@ -0,0 +1,432 @@ +using System.Collections.Concurrent; +using System.Collections.Immutable; +using Connected.Interop; + +namespace Connected.Caching; + +internal abstract class Cache : ICache +{ + private bool _disposedValue; + private readonly ConcurrentDictionary _items; + private readonly Task _scavenger; + private readonly CancellationTokenSource _cancel = new(); + + public event CacheInvalidateHandler? Invalidating; + public event CacheInvalidateHandler? Invalidated; + + public Cache() + { + _scavenger = new Task(OnScaveging, _cancel.Token, TaskCreationOptions.LongRunning); + _items = new ConcurrentDictionary(); + + _scavenger.Start(); + } + private ConcurrentDictionary Items => _items; + private CancellationTokenSource Cancel => _cancel; + + private void OnScaveging() + { + var token = Cancel.Token; + + while (!token.IsCancellationRequested) + { + try + { + foreach (var i in Items) + i.Value.Scave(); + + var empties = Items.Where(f => f.Value.Count == 0).Select(f => f.Key); + + foreach (var i in empties) + Items.TryRemove(i, out _); + + token.WaitHandle.WaitOne(TimeSpan.FromMinutes(1)); + } + catch { } + } + } + public virtual bool IsEmpty(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Any(); + + return true; + } + + public virtual bool Exists(string key) + { + return Items.ContainsKey(key); + } + + public void CreateKey(string key) + { + if (Exists(key)) + return; + + Items.TryAdd(key, new Entries()); + } + + public IEnumerator? GetEnumerator(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.GetEnumerator(); + + return new List().GetEnumerator(); + } + public virtual ImmutableList? All(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.All(); + + return default; + } + + public int Count(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Count; + + return 0; + } + + public virtual T? Get(string key, Func predicate) + { + if (Items.TryGetValue(key, out Entries? value) && value.Get(predicate) is IEntry entry) + return GetValue(entry); + + return default; + } + public virtual async Task Get(string key, Func predicate, Func> retrieve) + { + if (Items.TryGetValue(key, out Entries? value) && value.Get(predicate) is IEntry entry) + return GetValue(entry); + + if (retrieve is null) + return default; + + var options = new EntryOptions(); + T instance = await retrieve(options); + + if (EqualityComparer.Default.Equals(instance, default)) + { + if (!options.AllowNull) + return default; + } + + if (string.IsNullOrWhiteSpace(options.Key)) + throw new SysException(this, SR.ErrCacheKeyNull); + + Set(key, options.Key, instance, options.Duration, options.SlidingExpiration); + + return instance; + } + + public virtual async Task Get(string key, object id, Func>? retrieve) + { + if (Items.TryGetValue(key, out Entries? value) && value.Get(id is null ? null : id.ToString()) is IEntry entry) + return GetValue(entry); + + if (retrieve is null) + return default; + + var options = new EntryOptions + { + Key = id is null ? null : id.ToString() + }; + + T instance = await retrieve(options); + + if (EqualityComparer.Default.Equals(instance, default)) + { + if (!options.AllowNull) + return default; + } + + Set(key, options.Key, instance, options.Duration, options.SlidingExpiration); + + return instance; + } + + internal void ClearCore(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + value.Clear(); + } + + public virtual async Task Clear(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + value.Clear(); + + await Task.CompletedTask; + } + + public virtual T? Get(string key, object id) + { + if (Items.TryGetValue(key, out Entries? value) && value.Get(id is null ? null : id.ToString()) is IEntry entry) + return GetValue(entry); + + return default; + } + + public IEntry? Get(string key, object id) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Get(id is null ? null : id.ToString()); + + return default; + } + + public virtual T? Get(string key, Func predicate) + { + if (Items.TryGetValue(key, out Entries? value) && value.Get(predicate) is IEntry entry) + return GetValue(entry); + + return default; + } + + public virtual T? First(string key) + { + if (Items.TryGetValue(key, out Entries? value) && value.First() is IEntry entry) + return GetValue(entry); + + return default; + } + + public virtual ImmutableList? Where(string key, Func predicate) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Where(predicate); + + return default; + } + + public void CopyTo(string key, object id, IEntry instance) + { + if (!Items.TryGetValue(key, out Entries? value)) + { + value = new Entries(); + + if (!Items.TryAdd(key, value)) + return; + } + + value.Set(id is null ? null : id.ToString(), instance.Instance, instance.Duration, instance.SlidingExpiration); + } + + public virtual T? Set(string key, object id, T? instance) + { + return Set(key, id, instance, TimeSpan.Zero); + } + + public virtual T? Set(string key, object id, T? instance, TimeSpan duration) + { + return Set(key, id, instance, duration, false); + } + + public virtual T? Set(string key, object id, T? instance, TimeSpan duration, bool slidingExpiration) + { + if (!Items.TryGetValue(key, out Entries? value)) + { + value = new Entries(); + + if (!Items.TryAdd(key, value)) + return default; + } + + value.Set(id is null ? null : id.ToString(), instance, duration, slidingExpiration); + + return instance; + } + + internal void RemoveCore(string key, object id) + { + if (id is null) + return; + + if (Items.TryGetValue(key, out Entries? value)) + value.Remove(id.ToString()); + } + public virtual async Task Remove(string key, object id) + { + await Remove(key, id, true); + } + + private async Task Remove(string key, object id, bool removing) + { + if (Items.TryGetValue(key, out Entries? value)) + value.Remove(id is null ? null : id.ToString()); + + if (removing) + await OnRemove(key, id); + } + + protected virtual async Task OnRemove(string key, object id) + { + await Task.CompletedTask; + } + + public async Task Invalidate(string key, object id) + { + await InvalidateCore(key, id, false); + } + + internal async Task InvalidateCore(string key, object id, bool fromNotification) + { + /* + * we store existing instance but it is not + * removed from the cache yet. This is because other + * threads can access this instance while we are + * retrieving a new version from the source + */ + var existing = Get(key, id); + var args = new CacheEventArgs(id is null ? null : id.ToString(), key); + /* + * this two events invalidate that cache reference. + * note that if no new version exists the existing one + * is still available to other threads. + */ + try + { + Invalidating?.Invoke(args); + } + catch { } + + try + { + if (!fromNotification) + await OnInvalidating(args); + + await OnInvalidate(args); + } + catch { } + /* + * now find out if a new version has been set for the + * specified key + */ + var newInstance = Get(key, id); + /* + * if no existing reference exists there is no need for + * removing anything + */ + if (existing is not null) + { + /* + * we have an existing instance. we are dealing with two possible scenarios: + * - newInstance is null because entity has been deleted + * - newInstance is actually the same instance as the existing which means a new + * version does not exist. In both cases we must remove existing reference because + * at this point it is not valid anymore. + * note that the third case exists: reference has been replaced. in that case there + * is nothing to do because Invalidating events has already replaced reference with a + * new version. + */ + if (newInstance is null) + await Remove(key, id, false); + else if (existing.Equals(newInstance) && args.Behavior == InvalidateBehavior.RemoveSameInstance) + await Remove(key, id, false); + } + + try + { + Invalidated?.Invoke(args); + } + catch { } + } + + protected internal virtual async Task OnInvalidating(CacheEventArgs e) + { + await Task.CompletedTask; + } + + protected virtual async Task OnInvalidate(CacheEventArgs e) + { + await Task.CompletedTask; + } + + private void Clear() + { + foreach (var i in Items) + i.Value.Clear(); + + Items.Clear(); + } + + public virtual async Task?> Remove(string key, Func predicate) + { + if (Items.TryGetValue(key, out Entries? value)) + { + var result = value.Remove(predicate); + + if (result is not null && result.Any()) + await OnRemove(key, result); + } + + return default; + } + + protected virtual async Task OnRemove(string key, ImmutableList ids) + { + await Task.CompletedTask; + } + + public ImmutableList? Keys(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Keys; + + return default; + } + + public ImmutableList Keys() + { + return Items.Keys.ToImmutableList(); + } + + public bool Any(string key) + { + if (Items.TryGetValue(key, out Entries? value)) + return value.Any(); + + return false; + } + + private static T? GetValue(IEntry entry) + { + if (entry is null || entry.Instance is null) + return default; + + if (TypeConversion.TryConvert(entry.Instance, out T? result)) + return result; + + return default; + } + + protected virtual void OnDisposing(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + Cancel.Cancel(); + Clear(); + + if (_scavenger is not null) + { + _cancel.Cancel(); + + if (_scavenger.IsCompleted) + _scavenger.Dispose(); + } + } + + _disposedValue = true; + } + } + + public void Dispose() + { + OnDisposing(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Caching/CacheClient.cs b/Connected.Caching/CacheClient.cs new file mode 100644 index 0000000..77d22df --- /dev/null +++ b/Connected.Caching/CacheClient.cs @@ -0,0 +1,128 @@ +using System.Collections; +using System.Collections.Immutable; + +namespace Connected.Caching; + +public abstract class CacheClient : ICacheClient where TEntry : class +{ + protected CacheClient(ICachingService cachingService, string key) + { + if (cachingService is null) + throw new ArgumentException(nameof(cachingService)); + + CachingService = cachingService; + Key = key; + } + + public string Key { get; } + protected bool IsDisposed { get; set; } + protected async Task Remove(TKey id) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + await CachingService.Remove(Key, id); + } + + protected async Task Remove(Func predicate) + { + await CachingService.Remove(Key, predicate); + } + + protected async Task Refresh(TKey id) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + await CachingService.Invalidate(Key, id); + } + + protected ICachingService CachingService { get; } + + public int Count => CachingService.Count(Key); + protected virtual ICollection? Keys => CachingService.Keys(Key); + + protected virtual Task?> All() + { + return Task.FromResult(CachingService.All(Key)); + } + + protected virtual async Task Get(TKey id, Func> retrieve) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + return await CachingService.Get(Key, id, retrieve); + } + + protected virtual Task Get(TKey id) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + return Task.FromResult(CachingService.Get(Key, id)); + } + + protected virtual Task First() + { + return Task.FromResult(CachingService.First(Key)); + } + + protected virtual async Task Get(Func predicate) + { + return await CachingService.Get(Key, predicate, null); + } + + protected virtual Task?> Where(Func predicate) + { + return Task.FromResult(CachingService.Where(Key, predicate)); + } + + protected virtual void Set(TKey id, TEntry instance) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + CachingService.Set(Key, id, instance); + } + + protected virtual void Set(TKey id, TEntry instance, TimeSpan duration) + { + if (id is null) + throw new ArgumentNullException(nameof(id)); + + CachingService.Set(Key, id, instance, duration); + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + OnDisposing(); + + IsDisposed = true; + } + } + + protected virtual void OnDisposing() + { + + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public virtual IEnumerator GetEnumerator() + { + return CachingService?.GetEnumerator(Key); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } +} diff --git a/Connected.Caching/CacheContext.cs b/Connected.Caching/CacheContext.cs new file mode 100644 index 0000000..1ec5d39 --- /dev/null +++ b/Connected.Caching/CacheContext.cs @@ -0,0 +1,220 @@ +using System.Collections.Immutable; +using System.Reflection; +using Connected.Interop; +using Connected.ServiceModel.Transactions; + +namespace Connected.Caching; + +internal class CacheContext : Cache, ICacheContext +{ + public CacheContext(ICachingService cachingService, ITransactionContext transactionContext) + { + CachingService = cachingService; + TransactionContext = transactionContext; + + TransactionContext.StateChanged += OnTransactionContextStateChanged; + } + + private void OnTransactionContextStateChanged(object? sender, EventArgs e) + { + if (TransactionContext.State == MiddlewareTransactionState.Committing) + Flush(); + } + + private ICachingService CachingService { get; } + private ITransactionContext TransactionContext { get; } + + public override bool Exists(string key) + { + return base.Exists(key) || (CachingService is not null && CachingService.Exists(key)); + } + + public override bool IsEmpty(string key) + { + return base.IsEmpty(key) || (CachingService is not null && CachingService.IsEmpty(key)); + } + + public override ImmutableList? All(string key) + { + return Merge(base.All(key), CachingService?.All(key)); + } + + public override async Task Get(string key, object id, Func>? retrieve) + { + if (!TransactionContext.IsDirty) + { + if (retrieve is null) + return default; + + return await CachingService.Get(key, id, retrieve); + } + + return await base.Get(key, id, (f) => + { + var shared = CachingService.Get(key, id); + + if (shared is not null) + return Task.FromResult(shared); + + if (retrieve is null) + return default; + + return retrieve(f); + }); + } + + public override async Task Get(string key, Func predicate, Func>? retrieve) + { + if (!TransactionContext.IsDirty) + { + if (retrieve is null) + return default; + + return await CachingService.Get(key, predicate, retrieve); + } + + return await base.Get(key, predicate, (f) => + { + var shared = CachingService.Get(key, predicate); + + if (shared is not null) + return Task.FromResult(shared); + + return retrieve(f); + }); + } + + public override T Get(string key, object id) + { + var contextItem = base.Get(key, id); + + if (contextItem is not null) + return contextItem; + + return CachingService.Get(key, id); + } + + public override T Get(string key, Func predicate) + { + var contextItem = base.Get(key, predicate); + + if (contextItem is not null) + return contextItem; + + return CachingService.Get(key, predicate); + } + + public override async Task Clear(string key) + { + await base.Clear(key); + await CachingService.Clear(key); + } + public override T First(string key) + { + if (base.First(key) is T result) + return result; + + return CachingService.First(key); + } + + public override ImmutableList Where(string key, Func predicate) + { + return Merge(base.Where(key, predicate), CachingService.Where(key, predicate)); + } + + public override T Set(string key, object id, T instance) + { + if (!TransactionContext.IsDirty) + return CachingService.Set(key, id, instance); + + return base.Set(key, id, instance); + } + + public override T Set(string key, object id, T instance, TimeSpan duration) + { + if (!TransactionContext.IsDirty) + return CachingService.Set(key, id, instance, duration); + + return base.Set(key, id, instance, duration); + } + + public override T Set(string key, object id, T instance, TimeSpan duration, bool slidingExpiration) + { + if (!TransactionContext.IsDirty) + return CachingService.Set(key, id, instance, duration, slidingExpiration); + + return base.Set(key, id, instance, duration, slidingExpiration); + } + + public override async Task Remove(string key, object id) + { + await base.Remove(key, id); + await CachingService.Remove(key, id); + } + + public override async Task?> Remove(string key, Func predicate) + { + var local = await base.Remove(key, predicate); + var shared = await CachingService.Remove(key, predicate); + + if (local is not null && shared is not null) + return local.AddRange(shared); + + return local is not null ? local : shared; + } + + public void Flush() + { + CachingService.Merge(this); + } + + private static ImmutableList? Merge(ImmutableList? contextItems, ImmutableList? sharedItems) + { + if (contextItems is null) + return sharedItems; + + var result = new List(contextItems); + + foreach (var sharedItem in sharedItems) + { + if (sharedItem is null) + continue; + + if (CachingExtensions.GetCacheKeyProperty(sharedItem) is not PropertyInfo cacheProperty) + { + //Q: should we compare every property and add only if not matched? + contextItems.Add(sharedItem); + continue; + } + + if (FindExisting(cacheProperty.GetValue(sharedItems), contextItems) is null) + result.Add(sharedItem); + } + + return result.ToImmutableList(); + } + + private static T? FindExisting(object value, ImmutableList items) + { + if (items is null || items.IsEmpty) + return default; + + if (CachingExtensions.GetCacheKeyProperty(items[0]) is not PropertyInfo cacheProperty) + return default; + + foreach (var item in items) + { + var id = cacheProperty.GetValue(item); + + if (TypeComparer.Compare(id, value)) + return item; + } + + return default; + } + + protected override async Task OnInvalidate(CacheEventArgs e) + { + await CachingService.Invalidate(e.Key, e.Id); + } +} \ No newline at end of file diff --git a/Connected.Caching/CacheEventArgs.cs b/Connected.Caching/CacheEventArgs.cs new file mode 100644 index 0000000..d3fbdcc --- /dev/null +++ b/Connected.Caching/CacheEventArgs.cs @@ -0,0 +1,27 @@ +namespace Connected.Caching; + +public enum InvalidateBehavior : byte +{ + RemoveSameInstance = 1, + KeepSameInstance = 2 +} +public class CacheEventArgs : EventArgs +{ + public CacheEventArgs(string id, string key) + { + Key = key; + Id = id; + } + + public CacheEventArgs(string id, string key, InvalidateBehavior behavior) + { + Key = key; + Id = id; + Behavior = behavior; + } + + public string Id { get; init; } + public string Key { get; init; } + + public InvalidateBehavior Behavior { get; set; } = InvalidateBehavior.RemoveSameInstance; +} \ No newline at end of file diff --git a/Connected.Caching/CacheNotificationArgs.cs b/Connected.Caching/CacheNotificationArgs.cs new file mode 100644 index 0000000..6c24166 --- /dev/null +++ b/Connected.Caching/CacheNotificationArgs.cs @@ -0,0 +1,16 @@ +namespace Connected.Caching; + +public class CacheNotificationArgs +{ + public CacheNotificationArgs(string method) + { + if (string.IsNullOrWhiteSpace(method)) + throw new ArgumentException(null, nameof(method)); + + Method = method; + } + + public string? Key { get; init; } + public List? Ids { get; init; } + public string Method { get; } +} diff --git a/Connected.Caching/CachingExtensions.cs b/Connected.Caching/CachingExtensions.cs new file mode 100644 index 0000000..0c5f24d --- /dev/null +++ b/Connected.Caching/CachingExtensions.cs @@ -0,0 +1,13 @@ +using System.Reflection; +using Connected.Caching.Annotations; +using Connected.Interop; + +namespace Connected.Caching; + +public static class CachingExtensions +{ + public static PropertyInfo? GetCacheKeyProperty(object instance) + { + return Properties.GetPropertyAttribute(instance); + } +} diff --git a/Connected.Caching/CachingService.cs b/Connected.Caching/CachingService.cs new file mode 100644 index 0000000..ea372ff --- /dev/null +++ b/Connected.Caching/CachingService.cs @@ -0,0 +1,151 @@ +using System.Collections.Immutable; +using Connected.Caching.Net; +using Connected.Net.Server; + +namespace Connected.Caching; + +internal sealed class CachingService : MemoryCache, ICachingService, IDisposable, IAsyncDisposable +{ + public CachingService(IEndpointServer server, CacheServer state, CacheServerConnection backplaneClient) + { + if (server is null) + throw new ArgumentException(null, nameof(server)); + + if (state is null) + throw new ArgumentException(null, nameof(state)); + + BackplaneClient = backplaneClient; + + Server = server; + BackplaneServer = state; + + server.Changed += OnServerChanged; + server.Initialized += OnServerInitialized; + BackplaneServer.Received += OnReceived; + BackplaneClient.Received += OnReceived; + } + + private CacheServerConnection BackplaneClient { get; set; } + private CacheServer BackplaneServer { get; } + private IEndpointServer Server { get; } + + private async void OnServerInitialized(object? sender, EventArgs e) + { + await Initialize(); + } + + public async Task Initialize() + { + await BackplaneClient.Disconnect(); + + try + { + if (!await Server.IsServer()) + { + await BackplaneClient.Initialize(Server.ServerUrl); + await BackplaneClient.Connect(); + } + } + catch + { + // Server probably not initalized yet + } + } + + private async void OnReceived(object? sender, CacheNotificationArgs e) + { + if (string.Equals(e.Method, nameof(Clear), StringComparison.Ordinal)) + ClearCore(e.Key); + else if (string.Equals(e.Method, nameof(Remove), StringComparison.Ordinal)) + { + if (e.Ids is not null && e.Ids.Any()) + { + foreach (var id in e.Ids) + RemoveCore(e.Key, id); + } + } + else if (string.Equals(e.Method, nameof(Invalidate), StringComparison.Ordinal)) + { + if (e.Ids is not null && e.Ids.Any()) + { + foreach (var id in e.Ids) + await InvalidateCore(e.Key, id, true); + } + } + } + + private async void OnServerChanged(object? sender, ServerChangedArgs e) + { + await Initialize(); + } + + public override async Task Clear(string key) + { + await base.Clear(key); + + var args = new CacheNotificationArgs(nameof(Clear)) { Key = key }; + + if (await Server.IsServer()) + await BackplaneServer.Send(args); + else + await BackplaneClient.Notify(nameof(Clear), args); + + + } + protected internal override async Task OnInvalidating(CacheEventArgs e) + { + var args = new CacheNotificationArgs(nameof(Invalidate)) + { + Ids = new List { e.Id }, + Key = e.Key + }; + + if (await Server.IsServer()) + await BackplaneServer.Send(args); + else + await BackplaneClient.Notify(nameof(Invalidate), args); + } + + protected override async Task OnRemove(string key, ImmutableList ids) + { + await base.OnRemove(key, ids); + + var args = new CacheNotificationArgs(nameof(Remove)) + { + Ids = ids.ToList(), + Key = key + }; + + if (await Server.IsServer()) + await BackplaneServer.Send(args); + else + await BackplaneClient.Notify(nameof(Remove), args); + } + + protected override async Task OnRemove(string key, object? id) + { + await base.OnRemove(key, id); + + var ids = new List(); + + if (id is not null) + ids.Add(id.ToString()); + + await OnRemove(key, ids.ToImmutableList()); + } + + public async ValueTask DisposeAsync() + { + Server.Changed -= OnServerChanged; + Server.Initialized -= OnServerInitialized; + } + + + protected override void OnDisposing(bool disposing) + { + Server.Changed -= OnServerChanged; + Server.Initialized -= OnServerInitialized; + + base.OnDisposing(disposing); + } +} \ No newline at end of file diff --git a/Connected.Caching/CachingStartup.cs b/Connected.Caching/CachingStartup.cs new file mode 100644 index 0000000..fb999f4 --- /dev/null +++ b/Connected.Caching/CachingStartup.cs @@ -0,0 +1,32 @@ +using Connected.Annotations; +using Connected.Caching.Net; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Caching; + +internal class CachingStartup : Startup +{ + public const string CachingHub = "/caching"; + + protected override void OnConfigure(WebApplication app) + { + app.MapHub(CachingHub); + } + + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(CacheServer)); + services.AddSingleton(typeof(CacheServerConnection)); + services.AddSingleton(typeof(ICachingService), typeof(CachingService)); + services.AddScoped(typeof(ICacheContext), typeof(CacheContext)); + } + + protected override async Task OnInitialize(Dictionary args) + { + if (Services is not null && Services.GetService() is ICachingService service) + await service.Initialize(); + } +} diff --git a/Connected.Caching/Connected.Caching.csproj b/Connected.Caching/Connected.Caching.csproj new file mode 100644 index 0000000..c06d127 --- /dev/null +++ b/Connected.Caching/Connected.Caching.csproj @@ -0,0 +1,36 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Caching/Entries.cs b/Connected.Caching/Entries.cs new file mode 100644 index 0000000..01d9b69 --- /dev/null +++ b/Connected.Caching/Entries.cs @@ -0,0 +1,260 @@ +using System.Collections.Concurrent; +using System.Collections.Immutable; +using Connected.Interop; + +namespace Connected.Caching; + +internal class Entries +{ + private readonly Lazy> _items = new(); + + private ConcurrentDictionary Items => _items.Value; + public ImmutableList Keys => Items.Keys.ToImmutableList(); + public int Count => Items.Count; + + public bool Any() + { + return Items.Any(); + } + + public void Scave() + { + var expired = new HashSet(); + + foreach (var i in Items) + { + var r = i.Value; + + if (r is null || r.Expired) + expired.Add(i.Key); + } + + foreach (var i in expired) + Remove(i); + } + + public ImmutableList All() + { + var r = new List(); + var expired = Items.Where(f => f.Value.Expired); + + foreach (var e in expired) + Remove(e.Value.Id); + + var instances = Items.Select(f => f.Value.Instance); + + foreach (var i in instances) + { + if (TypeConversion.TryConvert(i, out T? result) && result is not null) + r.Add(result); + } + + return r.ToImmutableList(); + } + + public void Remove(string key) + { + if (Items.IsEmpty) + return; + + if (Items.TryRemove(key, out IEntry? v)) + v.Dispose(); + } + + public void Set(string key, object? instance, TimeSpan duration, bool slidingExpiration) + { + Items[key] = new Entry(key, instance, duration, slidingExpiration); + } + + public IEnumerator GetEnumerator() + { + return new EntryEnumerator(Items); + } + public IEntry? Get(string key) + { + return Find(key); + } + + public IEntry? First() + { + if (!Any()) + return default; + + return Items.First().Value; + } + + public IEntry? Get(Func predicate) + { + return Find(predicate); + } + + public IEntry? Get(Func predicate) + { + return Find(predicate); + } + + public ImmutableList? Remove(Func predicate) + { + if (Where(predicate) is not ImmutableList ds || ds.IsEmpty) + return default; + + var result = new HashSet(); + + foreach (var i in ds) + { + var key = Items.FirstOrDefault(f => InstanceEquals(f.Value?.Instance, i)).Key; + + RemoveInternal(key); + + result.Add(key); + } + + return result.ToImmutableList(); + } + + public ImmutableList? Where(Func predicate) + { + var values = Items.Select(f => f.Value.Instance).Cast(); + + if (values is null || !values.Any()) + return default; + + var filtered = values.Where(predicate); + + if (filtered is null || !filtered.Any()) + return default; + + var r = new List(); + + foreach (var i in filtered) + { + var ce = Items.FirstOrDefault(f => InstanceEquals(f.Value?.Instance, i)); + + if (ce.Value is null) + continue; + + if (ce.Value.Expired) + { + RemoveInternal(ce.Value.Id); + continue; + } + + ce.Value.Hit(); + r.Add(i); + } + + return r.ToImmutableList(); + } + + private void RemoveInternal(string key) + { + if (Items.TryRemove(key, out IEntry? d)) + d.Dispose(); + } + + private IEntry? Find(Func predicate) + { + var instances = Items.Select(f => f.Value?.Instance).Cast(); + + if (instances is null || !instances.Any()) + return default; + + if (instances.FirstOrDefault(predicate) is not T instance) + return default; + + var ce = Items.Values.FirstOrDefault(f => InstanceEquals(f.Instance, instance)); + + if (ce is null) + return default; + + if (ce.Expired) + { + RemoveInternal(ce.Id); + return default; + } + + ce.Hit(); + + return ce; + } + + private IEntry? Find(Func predicate) + { + var instances = Items.Select(f => f.Value?.Instance).Cast(); + + if (instances is null || !instances.Any()) + return default; + + if (instances.FirstOrDefault(predicate) is not T instance) + return default; + + if (Items.Values.FirstOrDefault(f => InstanceEquals(f.Instance, instance)) is not IEntry ce) + return default; + + if (ce.Expired) + { + RemoveInternal(ce.Id); + return default; + } + + ce.Hit(); + + return ce; + } + + private IEntry? Find(string key) + { + if (!Items.ContainsKey(key)) + return default; + + if (Items.TryGetValue(key, out IEntry? d)) + { + if (d.Expired) + { + RemoveInternal(key); + return default; + } + + d.Hit(); + + return d; + } + else + { + RemoveInternal(key); + + return default; + } + } + + public bool Exists(string key) + { + return Find(key) is not null; + } + + public void Clear() + { + foreach (var i in Items) + Remove(i.Key); + } + + private static bool InstanceEquals(object? left, object? right) + { + /* + * TODO: implement IEquality check + */ + if (left is null || right is null) + return false; + + if (left.GetType().IsPrimitive) + return left == right; + + if (left is string && right is string) + return string.Compare(left.ToString(), right.ToString(), false) == 0; + + if (left.GetType().IsValueType && right.GetType().IsValueType) + return left.Equals(right); + + return ReferenceEqualityComparer.Instance.Equals(left, right); + } +} \ No newline at end of file diff --git a/Connected.Caching/Entry.cs b/Connected.Caching/Entry.cs new file mode 100644 index 0000000..f31f648 --- /dev/null +++ b/Connected.Caching/Entry.cs @@ -0,0 +1,35 @@ +namespace Connected.Caching; + +internal class Entry : IEntry +{ + public Entry(string id, object? instance, TimeSpan duration, bool slidingExpiration) + { + Id = id; + Instance = instance; + SlidingExpiration = slidingExpiration; + Duration = duration; + + if (Duration > TimeSpan.Zero) + ExpirationDate = DateTime.UtcNow.AddTicks(duration.Ticks); + } + + public bool SlidingExpiration { get; } + private DateTime ExpirationDate { get; set; } + public TimeSpan Duration { get; set; } + + public object? Instance { get; } + public string Id { get; } + public bool Expired => ExpirationDate != DateTime.MinValue && ExpirationDate < DateTime.UtcNow; + + public void Hit() + { + if (SlidingExpiration && Duration > TimeSpan.Zero) + ExpirationDate = DateTime.UtcNow.AddTicks(Duration.Ticks); + } + + public void Dispose() + { + if (Instance is IDisposable disposable) + disposable.Dispose(); + } +} \ No newline at end of file diff --git a/Connected.Caching/EntryEnumerator.cs b/Connected.Caching/EntryEnumerator.cs new file mode 100644 index 0000000..19f26ab --- /dev/null +++ b/Connected.Caching/EntryEnumerator.cs @@ -0,0 +1,43 @@ +using System.Collections; +using System.Collections.Concurrent; +using Connected.Interop; + +namespace Connected.Caching; + +internal class EntryEnumerator : IEnumerator +{ + public EntryEnumerator(ConcurrentDictionary items) + { + Items = items; + Index = -1; + } + + private int Count => Items.Count; + private int Index { get; set; } + private ConcurrentDictionary Items { get; } + + public T Current => TypeConversion.TryConvert(Items.ElementAt(Index).Value.Instance, out T result) ? result : default; + + object IEnumerator.Current => Current; + + public void Dispose() + { + } + + public bool MoveNext() + { + if (Index < Count - 1) + { + Index++; + + return true; + } + + return false; + } + + public void Reset() + { + Index = -1; + } +} diff --git a/Connected.Caching/EntryOptions.cs b/Connected.Caching/EntryOptions.cs new file mode 100644 index 0000000..808df78 --- /dev/null +++ b/Connected.Caching/EntryOptions.cs @@ -0,0 +1,16 @@ +namespace Connected.Caching; + +public class EntryOptions +{ + public string Key { get; set; } + public string? KeyProperty { get; set; } + public TimeSpan Duration { get; set; } + public bool SlidingExpiration { get; set; } + public bool AllowNull { get; set; } + + public EntryOptions() + { + Duration = TimeSpan.FromMinutes(5); + SlidingExpiration = true; + } +} diff --git a/Connected.Caching/ICache.cs b/Connected.Caching/ICache.cs new file mode 100644 index 0000000..0ba025d --- /dev/null +++ b/Connected.Caching/ICache.cs @@ -0,0 +1,39 @@ +using System.Collections.Immutable; + +namespace Connected.Caching; + +public delegate void CacheInvalidateHandler(CacheEventArgs e); +public interface ICache : IDisposable +{ + event CacheInvalidateHandler? Invalidating; + event CacheInvalidateHandler? Invalidated; + + ImmutableList? All(string key); + + Task Get(string key, object id, Func>? retrieve); + T? Get(string key, object id); + IEntry? Get(string key, object id); + Task Get(string key, Func predicate, Func>? retrieve); + T? Get(string key, Func predicate); + T? First(string key); + IEnumerator? GetEnumerator(string key); + + ImmutableList? Where(string key, Func predicate); + bool Exists(string key); + bool IsEmpty(string key); + void CreateKey(string key); + Task Clear(string key); + + T? Set(string key, object id, T? instance); + T? Set(string key, object id, T? instance, TimeSpan duration); + T? Set(string key, object id, T? instance, TimeSpan duration, bool slidingExpiration); + void CopyTo(string key, object id, IEntry entry); + Task?> Remove(string key, Func predicate); + Task Remove(string key, object id); + Task Invalidate(string key, object id); + + int Count(string key); + bool Any(string key); + ImmutableList? Keys(string key); + ImmutableList? Keys(); +} diff --git a/Connected.Caching/ICacheClient.cs b/Connected.Caching/ICacheClient.cs new file mode 100644 index 0000000..2e14b84 --- /dev/null +++ b/Connected.Caching/ICacheClient.cs @@ -0,0 +1,7 @@ +namespace Connected.Caching; + +public interface ICacheClient : IEnumerable, IDisposable +{ + string Key { get; } + int Count { get; } +} diff --git a/Connected.Caching/ICacheContext.cs b/Connected.Caching/ICacheContext.cs new file mode 100644 index 0000000..4da7c4e --- /dev/null +++ b/Connected.Caching/ICacheContext.cs @@ -0,0 +1,6 @@ +namespace Connected.Caching; + +public interface ICacheContext : ICache +{ + void Flush(); +} diff --git a/Connected.Caching/ICachingService.cs b/Connected.Caching/ICachingService.cs new file mode 100644 index 0000000..404b576 --- /dev/null +++ b/Connected.Caching/ICachingService.cs @@ -0,0 +1,7 @@ +namespace Connected.Caching; + +public interface ICachingService : ICache +{ + void Merge(ICache cache); + Task Initialize(); +} \ No newline at end of file diff --git a/Connected.Caching/IEntry.cs b/Connected.Caching/IEntry.cs new file mode 100644 index 0000000..2f70bcc --- /dev/null +++ b/Connected.Caching/IEntry.cs @@ -0,0 +1,11 @@ +namespace Connected.Caching; + +public interface IEntry : IDisposable +{ + object? Instance { get; } + string Id { get; } + bool Expired { get; } + TimeSpan Duration { get; } + bool SlidingExpiration { get; } + void Hit(); +} diff --git a/Connected.Caching/IMemoryCache.cs b/Connected.Caching/IMemoryCache.cs new file mode 100644 index 0000000..6813866 --- /dev/null +++ b/Connected.Caching/IMemoryCache.cs @@ -0,0 +1,6 @@ +namespace Connected.Caching; + +public interface IMemoryCache : ICache +{ + void Merge(ICache cache); +} diff --git a/Connected.Caching/IStatefulCacheClient.cs b/Connected.Caching/IStatefulCacheClient.cs new file mode 100644 index 0000000..5c5f61f --- /dev/null +++ b/Connected.Caching/IStatefulCacheClient.cs @@ -0,0 +1,6 @@ +namespace Connected.Caching; + +public interface IStatefulCacheClient : ICacheClient +{ + event CacheInvalidateHandler Invalidate; +} diff --git a/Connected.Caching/MemoryCache.cs b/Connected.Caching/MemoryCache.cs new file mode 100644 index 0000000..f835638 --- /dev/null +++ b/Connected.Caching/MemoryCache.cs @@ -0,0 +1,24 @@ +using System.Collections.Immutable; + +namespace Connected.Caching; + +internal class MemoryCache : Cache, IMemoryCache +{ + public void Merge(ICache cache) + { + if (cache.Keys() is not ImmutableList keys) + return; + + foreach (var key in keys) + { + if (cache.Keys(key) is not ImmutableList entryKeys) + continue; + + foreach (var entryKey in entryKeys) + { + if (cache.Get(key, entryKey) is IEntry entry) + CopyTo(key, entryKey, entry); + } + } + } +} \ No newline at end of file diff --git a/Connected.Caching/Net/CacheHub.cs b/Connected.Caching/Net/CacheHub.cs new file mode 100644 index 0000000..ebf3068 --- /dev/null +++ b/Connected.Caching/Net/CacheHub.cs @@ -0,0 +1,11 @@ +using Connected.Net.Hubs; + +namespace Connected.Caching.Net; + +//TODO: implement authorization and add logic to reject connections if not an endpoint server +internal class CacheHub : StatefulHub +{ + public CacheHub(CacheServer server) : base(server) + { + } +} diff --git a/Connected.Caching/Net/CacheServer.cs b/Connected.Caching/Net/CacheServer.cs new file mode 100644 index 0000000..e28ef28 --- /dev/null +++ b/Connected.Caching/Net/CacheServer.cs @@ -0,0 +1,11 @@ +using Connected.Net.Hubs; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Caching.Net; + +internal class CacheServer : Server +{ + public CacheServer(IHubContext hub) : base(hub) + { + } +} diff --git a/Connected.Caching/Net/CacheServerConnection.cs b/Connected.Caching/Net/CacheServerConnection.cs new file mode 100644 index 0000000..a6a9aa3 --- /dev/null +++ b/Connected.Caching/Net/CacheServerConnection.cs @@ -0,0 +1,36 @@ +using Connected.Net.Hubs; +using Connected.Net.Messaging; +using Connected.Net.Server; +using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.Extensions.Logging; + +namespace Connected.Caching.Net; + +internal class CacheServerConnection : ServerConnection +{ + public event EventHandler? Received; + + public CacheServerConnection(IEndpointServer server, ILogger logger) : base(server) + { + Logger = logger; + } + + private ILogger Logger { get; } + + public override async Task Initialize(string hubUrl) + { + await base.Initialize(hubUrl); + + Connection.On("Notify", (a, e) => + { + Connection.InvokeAsync(nameof(CacheHub.Acknowledge), a); + + Received?.Invoke(this, e); + }); + + Connection.On("Exception", (e) => + { + Logger.LogError("Caching hub exception: {message}", e.Message); + }); + } +} diff --git a/Connected.Caching/Net/CacheWorker.cs b/Connected.Caching/Net/CacheWorker.cs new file mode 100644 index 0000000..2f156d0 --- /dev/null +++ b/Connected.Caching/Net/CacheWorker.cs @@ -0,0 +1,11 @@ +using Connected.Net.Hubs; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Caching.Net; + +internal sealed class CacheWorker : ServerWorker +{ + public CacheWorker(CacheServer server, IHubContext hub) : base(server, hub) + { + } +} diff --git a/Connected.Caching/SR.Designer.cs b/Connected.Caching/SR.Designer.cs new file mode 100644 index 0000000..cedd039 --- /dev/null +++ b/Connected.Caching/SR.Designer.cs @@ -0,0 +1,72 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Caching { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Caching.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Cache Key property not set. Please set 'Key' property before returning value from cache retrieve handler or set CacheKeyAttribute on at least one property.. + /// + internal static string ErrCacheKeyNull { + get { + return ResourceManager.GetString("ErrCacheKeyNull", resourceCulture); + } + } + } +} diff --git a/Connected.Caching/SR.resx b/Connected.Caching/SR.resx new file mode 100644 index 0000000..b1e3d01 --- /dev/null +++ b/Connected.Caching/SR.resx @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cache Key property not set. Please set 'Key' property before returning value from cache retrieve handler or set CacheKeyAttribute on at least one property. + + \ No newline at end of file diff --git a/Connected.Caching/StatefulCacheClient.cs b/Connected.Caching/StatefulCacheClient.cs new file mode 100644 index 0000000..83ac776 --- /dev/null +++ b/Connected.Caching/StatefulCacheClient.cs @@ -0,0 +1,113 @@ +using System.Collections.Immutable; +using System.Globalization; +using Connected.Interop; +using Connected.Threading; + +namespace Connected.Caching; + +public abstract class StatefulCacheClient : CacheClient, IStatefulCacheClient where TEntry : class +{ + public event CacheInvalidateHandler Invalidate; + protected StatefulCacheClient(ICachingService cachingService, string key) : base(cachingService, key) + { + Locker = new AsyncLockerSlim(); + + CachingService.Invalidating += OnInvalidate; + } + private AsyncLockerSlim? Locker { get; set; } + protected virtual InvalidateBehavior InvalidateBehavior { get; } = InvalidateBehavior.KeepSameInstance; + private bool Initialized { get; set; } + private async void OnInvalidate(CacheEventArgs e) + { + if (string.Equals(e.Key, Key, StringComparison.OrdinalIgnoreCase)) + { + if (Initialized) + await OnInvalidate(TypeConversion.Convert(e.Id, CultureInfo.InvariantCulture)); + + Invalidate?.Invoke(e); + + e.Behavior = InvalidateBehavior; + } + } + protected virtual async Task OnInvalidate(TKey id) + { + await Task.CompletedTask; + } + protected virtual async Task OnInitializing() + { + await Task.CompletedTask; + } + protected async Task Initialize() + { + if (Initialized || IsDisposed || Locker is null) + return; + + await Locker.LockAsync(async () => + { + if (Initialized || IsDisposed) + return; + + await OnInitializing(); + + Initialized = true; + }); + + if (Initialized) + await OnInitialized(); + } + protected virtual async Task OnInitialized() + { + await Task.CompletedTask; + } + protected override async Task?> All() + { + await Initialize(); + + return base.All().Result; + } + protected override async Task First() + { + await Initialize(); + + return await base.First(); + } + protected override async Task Get(Func predicate) + { + await Initialize(); + + return await base.Get(predicate); + } + protected override async Task Get(TKey id) + { + await Initialize(); + + return await base.Get(id); + } + protected override async Task Get(TKey id, Func> retrieve) + { + await Initialize(); + + return await base.Get(id, retrieve); + } + protected override async Task?> Where(Func predicate) + { + await Initialize(); + + return await base.Where(predicate); + } + protected override void OnDisposing() + { + if (Locker is not null) + { + Locker?.Dispose(); + Locker = null; + } + } + + public override IEnumerator GetEnumerator() + { + AsyncUtils.RunSync(Initialize); + + return base.GetEnumerator(); + } +} \ No newline at end of file diff --git a/Connected.Collections/CollectionExtensions.cs b/Connected.Collections/CollectionExtensions.cs new file mode 100644 index 0000000..d02e287 --- /dev/null +++ b/Connected.Collections/CollectionExtensions.cs @@ -0,0 +1,76 @@ +using System.Collections.Immutable; +using System.Reflection; +using Connected.Annotations; + +namespace Connected.Collections; + +public static class CollectionExtensions +{ + public static ImmutableArray ToImmutableArray(this IEnumerable items, bool performLock) + { + if (!performLock) + return items.ToImmutableArray(); + + lock (items) + return items.ToImmutableArray(); + } + + public static ImmutableList ToImmutableList(this IEnumerable items, bool performLock) + { + if (!performLock) + return items.ToImmutableList(); + + lock (items) + return items.ToImmutableList(); + } + + public static void SortByOrdinal(this List items) + { + items.Sort((left, right) => + { + var leftOrdinal = left is Type lt ? lt.GetCustomAttribute() : left?.GetType().GetCustomAttribute(); + var rightOrdinal = right is Type rt ? rt.GetCustomAttribute() : right?.GetType().GetCustomAttribute(); + + if (leftOrdinal is null && rightOrdinal is null) + return 0; + + if (leftOrdinal is not null && rightOrdinal is null) + return -1; + + if (leftOrdinal is null && rightOrdinal is not null) + return 1; + + if (leftOrdinal?.Value == rightOrdinal?.Value) + return 0; + else if (leftOrdinal?.Value < rightOrdinal?.Value) + return 1; + else + return -1; + }); + } + + public static void SortByPriority(this List items) + { + items.Sort((left, right) => + { + var leftPriority = left is Type lt ? lt.GetCustomAttribute() : left?.GetType().GetCustomAttribute(); + var rightPriority = right is Type rt ? rt.GetCustomAttribute() : right?.GetType().GetCustomAttribute(); + + if (leftPriority is null && rightPriority is null) + return 0; + + if (leftPriority is not null && rightPriority is null) + return -1; + + if (leftPriority is null && rightPriority is not null) + return 1; + + if (leftPriority?.Value == rightPriority?.Value) + return 0; + else if (leftPriority?.Value > rightPriority?.Value) + return -1; + else + return 1; + }); + } +} diff --git a/Connected.Collections/CollectionRoutes.cs b/Connected.Collections/CollectionRoutes.cs new file mode 100644 index 0000000..b776866 --- /dev/null +++ b/Connected.Collections/CollectionRoutes.cs @@ -0,0 +1,6 @@ +namespace Connected.Collections; + +public static class CollectionRoutes +{ + public const string Queue = "/sys/queue"; +} diff --git a/Connected.Collections/CollectionsStartup.cs b/Connected.Collections/CollectionsStartup.cs new file mode 100644 index 0000000..f1bf2c0 --- /dev/null +++ b/Connected.Collections/CollectionsStartup.cs @@ -0,0 +1,16 @@ +using Connected.Annotations; +using Microsoft.AspNetCore.Builder; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Collections; + +internal class CollectionsStartup : Startup +{ + public static WebApplication? Application { get; private set; } + + protected override void OnConfigure(WebApplication app) + { + Application = app; + } +} diff --git a/Connected.Collections/Concurrent/Dispatcher.cs b/Connected.Collections/Concurrent/Dispatcher.cs new file mode 100644 index 0000000..2636dba --- /dev/null +++ b/Connected.Collections/Concurrent/Dispatcher.cs @@ -0,0 +1,177 @@ +using System.Collections.Concurrent; +using Connected; +using Microsoft.Extensions.DependencyInjection; + +namespace Connected.Collections.Concurrent; + +public abstract class Dispatcher : IDispatcher + where TJob : IDispatcherJob +{ + private CancellationTokenSource _tokenSource; + protected Dispatcher(int size) + { + WorkerSize = size; + + _tokenSource = new(); + + Queue = new(); + Jobs = new(); + QueuedDispatchers = new(); + } + + public CancellationToken CancellationToken => _tokenSource.Token; + private ConcurrentQueue Queue { get; set; } + private List> Jobs { get; set; } + protected bool IsDisposed { get; private set; } + private int WorkerSize { get; } + public int Available => Math.Max(0, WorkerSize * 4 - Queue.Count - QueuedDispatchers.Sum(f => f.Value.Count)); + private ConcurrentDictionary> QueuedDispatchers { get; set; } + public DispatcherProcessBehavior Behavior => DispatcherProcessBehavior.Parallel; + + public void Cancel() + { + _tokenSource?.Cancel(); + } + + public bool Dequeue(out TArgs? item) + { + return Queue.TryDequeue(out item); + } + public bool Enqueue(string queue, TArgs item) + { + if (EnsureDispatcher(queue) is not QueuedDispatcher dispatcher) + throw new SysException(this, $"{SR.ErrCreateQueuedDispatcher} ({queue})"); + + return dispatcher.Enqueue(item); + } + public bool Enqueue(TArgs item) + { + Queue.Enqueue(item); + + if (Jobs.Count < WorkerSize) + { + /* + * Dispatcher jobs should be transient so it's safe to request a service from the root collection. + */ + if (CollectionsStartup.Application.Services.GetService() is not DispatcherJob job) + throw new NullReferenceException($"{SR.ErrCreateService} ({typeof(DispatcherJob).Name})"); + + job.Completed += OnCompleted; + + lock (Jobs) + { + Jobs.Add(job); + } + + job.Run(Queue, CancellationToken); + } + + return true; + } + + private void OnCompleted(object? sender, EventArgs e) + { + if (sender is not DispatcherJob job) + return; + + if (Queue.IsEmpty) + { + lock (Jobs) + { + Jobs.Remove(job); + } + + job.Dispose(); + } + else + job.Run(Queue, CancellationToken); + } + + private QueuedDispatcher? EnsureDispatcher(string queueName) + { + if (QueuedDispatchers.TryGetValue(queueName, out QueuedDispatcher? result)) + return result; + + result = new QueuedDispatcher(this, queueName); + + result.Completed += OnQueuedCompleted; + + if (!QueuedDispatchers.TryAdd(queueName, result)) + { + result.Completed -= OnQueuedCompleted; + + if (QueuedDispatchers.TryGetValue(queueName, out QueuedDispatcher? retryResult)) + return retryResult; + else + return default; + } + + return result; + } + private void OnQueuedCompleted(object? sender, EventArgs e) + { + if (sender is not QueuedDispatcher dispatcher) + return; + + if (dispatcher.Count > 0) + return; + + QueuedDispatchers.Remove(dispatcher.QueueName, out _); + + dispatcher.Dispose(); + } + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + if (_tokenSource is not null) + { + if (!_tokenSource.IsCancellationRequested) + _tokenSource.Cancel(); + + _tokenSource.Dispose(); + _tokenSource = null; + } + + if (Queue is not null) + { + Queue.Clear(); + Queue = null; + } + + if (Jobs is not null) + { + foreach (var job in Jobs) + job.Dispose(); + + Jobs.Clear(); + Jobs = null; + } + + if (QueuedDispatchers is not null) + { + foreach (var dispatcher in QueuedDispatchers) + dispatcher.Value.Dispose(); + + QueuedDispatchers.Clear(); + Queue = null; + } + } + + IsDisposed = true; + } + } + + protected virtual void OnDisposing() + { + + } + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} + diff --git a/Connected.Collections/Concurrent/DispatcherJob.cs b/Connected.Collections/Concurrent/DispatcherJob.cs new file mode 100644 index 0000000..e3d563e --- /dev/null +++ b/Connected.Collections/Concurrent/DispatcherJob.cs @@ -0,0 +1,103 @@ +using System.Collections.Concurrent; +using System.ComponentModel; +using Connected.Data; + +namespace Connected.Collections.Concurrent; + +/// +/// This class acts as a job unit of the . +/// +/// +public abstract class DispatcherJob : IDispatcherJob, IDisposable +{ + public event EventHandler? Completed; + public bool IsRunning { get; private set; } + protected bool IsDisposed { get; private set; } + private CancellationToken CancellationToken { get; set; } + + internal void Run(ConcurrentQueue queue, CancellationToken cancellationToken) + { + CancellationToken = cancellationToken; + + if (IsRunning) + return; + + Task.Run(async () => + { + IsRunning = true; + TArgs? item = default; + + try + { + while (queue.TryDequeue(out item)) + { + if (item is null) + continue; + + if (item is IPopReceipt pr && pr.NextVisible <= DateTime.UtcNow) + continue; + + await Invoke(item); + + if (cancellationToken.IsCancellationRequested || IsDisposed) + break; + } + } + catch (Exception ex) + { + await HandleException(item, ex); + } + + IsRunning = false; + Completed?.Invoke(this, EventArgs.Empty); + + }, CancellationToken); + } + + private void OnCompleted(object? sender, RunWorkerCompletedEventArgs e) + { + Completed?.Invoke(this, EventArgs.Empty); + } + + public async Task Invoke(TArgs args) + { + await OnInvoke(args, CancellationToken); + } + + protected virtual async Task OnInvoke(TArgs args, CancellationToken cancellationToken) + { + await Task.CompletedTask; + } + + private async Task HandleException(TArgs? args, Exception ex) + { + await OnHandleEception(args, ex); + } + + protected virtual async Task OnHandleEception(TArgs? args, Exception ex) + { + await Task.CompletedTask; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (IsDisposed) + return; + + if (disposing) + OnDisposing(); + + IsDisposed = true; + } + + protected virtual void OnDisposing() + { + + } +} \ No newline at end of file diff --git a/Connected.Collections/Concurrent/IDispatcher.cs b/Connected.Collections/Concurrent/IDispatcher.cs new file mode 100644 index 0000000..dedf2a9 --- /dev/null +++ b/Connected.Collections/Concurrent/IDispatcher.cs @@ -0,0 +1,19 @@ +namespace Connected.Collections.Concurrent; + +public enum DispatcherProcessBehavior +{ + Parallel = 1, + Queued = 2 +} + +public interface IDispatcher : IDisposable + where TJob : IDispatcherJob +{ + bool Dequeue(out TArgs? item); + bool Enqueue(TArgs item); + bool Enqueue(string queue, TArgs item); + DispatcherProcessBehavior Behavior { get; } + + CancellationToken CancellationToken { get; } + void Cancel(); +} diff --git a/Connected.Collections/Concurrent/IDispatcherJob.cs b/Connected.Collections/Concurrent/IDispatcherJob.cs new file mode 100644 index 0000000..2eb66ca --- /dev/null +++ b/Connected.Collections/Concurrent/IDispatcherJob.cs @@ -0,0 +1,7 @@ +namespace Connected.Collections.Concurrent; + +public interface IDispatcherJob : IDisposable +{ + Task Invoke(TArgs args); + bool IsRunning { get; } +} diff --git a/Connected.Collections/Concurrent/QueuedDispatcher.cs b/Connected.Collections/Concurrent/QueuedDispatcher.cs new file mode 100644 index 0000000..54a241e --- /dev/null +++ b/Connected.Collections/Concurrent/QueuedDispatcher.cs @@ -0,0 +1,99 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.DependencyInjection; + +namespace Connected.Collections.Concurrent; + +internal sealed class QueuedDispatcher : IDispatcher + where TJob : IDispatcherJob +{ + public event EventHandler? Completed; + + public QueuedDispatcher(IDispatcher dispatcher, string queueName) + { + Dispatcher = dispatcher; + Queue = new(); + QueueName = queueName; + + /* + * Dispatcher jobs should be transient so it's safe to request a service from the root collection. + */ + if (CollectionsStartup.Application?.Services.GetService>() is not DispatcherJob job) + throw new SysException(this, $"{SR.ErrCreateService} ({typeof(DispatcherJob).Name})"); + + job.Completed += OnCompleted; + + Job = job; + } + + public CancellationToken CancellationToken => Dispatcher.CancellationToken; + public bool IsDisposed { get; private set; } + public DispatcherProcessBehavior Behavior => DispatcherProcessBehavior.Queued; + public string QueueName { get; } + private DispatcherJob Job { get; set; } + private IDispatcher Dispatcher { get; set; } + public int Count => Queue.Count; + private ConcurrentQueue Queue { get; set; } + + public void Cancel() + { + + } + + public bool Dequeue(out TArgs? item) + { + return Queue.TryDequeue(out item); + } + public bool Enqueue(TArgs item) + { + if (IsDisposed) + return false; + + Queue.Enqueue(item); + + if (!Job.IsRunning) + Job.Run(Queue, CancellationToken); + + return true; + } + + private void OnCompleted(object? sender, EventArgs e) + { + if (sender is not DispatcherJob job) + return; + + if (!Queue.IsEmpty) + { + job.Run(Queue, CancellationToken); + return; + } + + Completed?.Invoke(this, EventArgs.Empty); + } + + public bool Enqueue(string queue, TArgs args) + { + return Dispatcher.Enqueue(queue, args); + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + if (Job is not null) + { + Job.Dispose(); + Job = null; + } + } + + IsDisposed = true; + } + } + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Collections/Connected.Collections.csproj b/Connected.Collections/Connected.Collections.csproj new file mode 100644 index 0000000..3896ffa --- /dev/null +++ b/Connected.Collections/Connected.Collections.csproj @@ -0,0 +1,30 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Collections/Iterator.cs b/Connected.Collections/Iterator.cs new file mode 100644 index 0000000..3f44036 --- /dev/null +++ b/Connected.Collections/Iterator.cs @@ -0,0 +1,33 @@ +using Connected.Collections.Iterators; + +namespace Connected.Collections; + +public class Iterator +{ + private IIterator _iterator; + public Iterator(object value) + { + if (value is null) + throw new ArgumentException(nameof(value)); + + if (DictionaryIterator.CanHandle(value)) + _iterator = new DictionaryIterator(value); + else if (ListIterator.CanHandle(value)) + _iterator = new ListIterator(value); + else if (ArrayIterator.CanHandle(value)) + _iterator = new ArrayIterator(value); + } + + public async Task MoveNext(Func> processor) + { + if (!_iterator.MoveNext()) + return false; + + if (await processor(_iterator.Current) is object value) + _iterator.Add(value); + + return true; + } + + public object? Result => _iterator.Result; +} diff --git a/Connected.Collections/Iterators/ArrayIterator.cs b/Connected.Collections/Iterators/ArrayIterator.cs new file mode 100644 index 0000000..87b62e1 --- /dev/null +++ b/Connected.Collections/Iterators/ArrayIterator.cs @@ -0,0 +1,32 @@ +namespace Connected.Collections.Iterators; + +internal class ArrayIterator : IIterator +{ + public ArrayIterator(object value) + { + + } + public object? Result => throw new NotImplementedException(); + + public object Current => throw new NotImplementedException(); + + public static bool CanHandle(object value) + { + return value.GetType().IsArray; + } + + public void Add(object value) + { + throw new NotImplementedException(); + } + + public bool MoveNext() + { + throw new NotImplementedException(); + } + + public void Reset() + { + throw new NotImplementedException(); + } +} diff --git a/Connected.Collections/Iterators/DictionaryIterator.cs b/Connected.Collections/Iterators/DictionaryIterator.cs new file mode 100644 index 0000000..667e741 --- /dev/null +++ b/Connected.Collections/Iterators/DictionaryIterator.cs @@ -0,0 +1,75 @@ +using System.Collections; +using System.Collections.Immutable; +using Connected.Interop; + +namespace Connected.Collections.Iterators; + +internal class DictionaryIterator : IIterator +{ + private readonly object _value; + private IDictionary _result; + public DictionaryIterator(object value) + { + _value = value; + + if (value is not IDictionary dictionary) + return; + + Enumerator = dictionary.GetEnumerator(); + + CreateResult(); + } + + public static bool CanHandle(object value) + { + var arguments = value.GetType().GetGenericArguments(); + var kvp = typeof(KeyValuePair<,>).MakeGenericType(arguments); + var en = typeof(IEnumerable<>).MakeGenericType(kvp); + + return value.GetType() == en; + } + private void CreateResult() + { + if (_value.GetType() == typeof(IImmutableDictionary<,>)) + _result = _value.GetType().MakeGenericType(_value.GetType().GenericTypeArguments).CreateInstance(); + } + + private IDictionaryEnumerator? Enumerator { get; } + public object? Current => Enumerator?.Value; + + public object? Result + { + get + { + if (_value is null) + return null; + + throw new NotImplementedException(); + } + } + + public bool MoveNext() + { + if (Enumerator is null) + return false; + + return Enumerator.MoveNext(); + } + + public void Reset() + { + if (Enumerator is null) + return; + + Enumerator.Reset(); + } + + + public void Add(object value) + { + if (Current is null || Enumerator is null) + throw new InvalidOperationException(SR.ErrIteratorCurrentNull); + + _result.Add(Enumerator.Key, value); + } +} diff --git a/Connected.Collections/Iterators/IIterator.cs b/Connected.Collections/Iterators/IIterator.cs new file mode 100644 index 0000000..b891826 --- /dev/null +++ b/Connected.Collections/Iterators/IIterator.cs @@ -0,0 +1,10 @@ +using System.Collections; + +namespace Connected.Collections.Iterators; + +internal interface IIterator : IEnumerator +{ + void Add(object value); + + object? Result { get; } +} diff --git a/Connected.Collections/Iterators/ListIterator.cs b/Connected.Collections/Iterators/ListIterator.cs new file mode 100644 index 0000000..c936b81 --- /dev/null +++ b/Connected.Collections/Iterators/ListIterator.cs @@ -0,0 +1,36 @@ +namespace Connected.Collections.Iterators; + +internal class ListIterator : IIterator +{ + public ListIterator(object value) + { + + } + + public object? Result => throw new NotImplementedException(); + + public object Current => throw new NotImplementedException(); + + public static bool CanHandle(object value) + { + var arguments = value.GetType().GetGenericArguments(); + var list = typeof(IList<>).MakeGenericType(arguments); + + return value.GetType() == list; + } + + public void Add(object value) + { + throw new NotImplementedException(); + } + + public bool MoveNext() + { + throw new NotImplementedException(); + } + + public void Reset() + { + throw new NotImplementedException(); + } +} diff --git a/Connected.Collections/Queues/IQueueClient.cs b/Connected.Collections/Queues/IQueueClient.cs new file mode 100644 index 0000000..77b40ba --- /dev/null +++ b/Connected.Collections/Queues/IQueueClient.cs @@ -0,0 +1,6 @@ +namespace Connected.Collections.Queues; +public interface IQueueClient : IMiddleware + where TArgs : QueueArgs +{ + Task Invoke(IQueueMessage message, TArgs args); +} diff --git a/Connected.Collections/Queues/IQueueMessage.cs b/Connected.Collections/Queues/IQueueMessage.cs new file mode 100644 index 0000000..17c144e --- /dev/null +++ b/Connected.Collections/Queues/IQueueMessage.cs @@ -0,0 +1,48 @@ +using Connected.Data; + +namespace Connected.Collections.Queues; +/// +/// Represents a single queue message. +/// +/// +/// A queue message represents a unit of queued or deferred work which +/// can be processed distributed anywhere or in any client which +/// has access to the Queue REST service. +/// +public interface IQueueMessage : IPrimaryKey, IPopReceipt +{ + /// + /// Date date and time the queue message was created. + /// + DateTime Created { get; init; } + /// + /// The number of times the clients dequeued the message. + /// + /// + /// There are numerous reasons why queue message gets dequeued multiple + /// times. It could be that not all conditions were met at the time + /// of processing or that queue message was not processed quich enough and + /// its pop receipt expired. In such cases message returns to the queue and + /// waits for the next client to dequeue it. + /// + int DequeueCount { get; init; } + /// + /// The timestamp message was last dequeued. + /// + DateTime? DequeueTimestamp { get; init; } + /// + /// The queue to which the message belongs. + /// + /// + /// Every queue client must specify which queue processes. + /// + string Queue { get; init; } + /// + /// The arguments object which contains information about the message. + /// + /// + /// Most queue messages do have an argument object, mostly providing na id of the + /// entity or record for which processing should be performed. + /// + QueueArgs Arguments { get; init; } +} diff --git a/Connected.Collections/Queues/IQueueService.cs b/Connected.Collections/Queues/IQueueService.cs new file mode 100644 index 0000000..3378986 --- /dev/null +++ b/Connected.Collections/Queues/IQueueService.cs @@ -0,0 +1,34 @@ +using System.Collections.Immutable; +using Connected.Annotations; + +namespace Connected.Collections.Queues; +/// +/// Represents a distributed service for processing queue messages. +/// +/// +/// Queue mechanism is mostly used as an internal logic of processes +/// and resources to offload work from the main thread to achieve better +/// responsiveness of the system. Aggregations and calculations are good +/// examples of queue usage. You should use queue whenever you must +/// perform any kind of work that is not necessary to perform it in a single +/// transaction scope. +/// +[Service] +[ServiceUrl(CollectionRoutes.Queue)] +public interface IQueueService +{ + /// + /// Enqueues the queue message. + /// + /// The type of the arguments used in queue message + /// The arguments containing information about a queue message. + Task Enqueue(TArgs args) + where TClient : IQueueClient + where TArgs : QueueArgs; + /// + /// Dequeues the queue messages based on the provided arguments. + /// + /// The arguments containing information about dequeue criteria. + /// A list of valid queue messages that can be immediatelly processed.S + Task> Dequeue(DequeueArgs args); +} diff --git a/Connected.Collections/Queues/QueueArgs.cs b/Connected.Collections/Queues/QueueArgs.cs new file mode 100644 index 0000000..47809d3 --- /dev/null +++ b/Connected.Collections/Queues/QueueArgs.cs @@ -0,0 +1,44 @@ +using System.ComponentModel.DataAnnotations; +using Connected.Annotations; +using Connected.ServiceModel; + +namespace Connected.Collections.Queues; + +public class QueueArgs : Dto +{ + public QueueArgs() + { + Options = new(); + } + + public EnqueueOptions Options { get; set; } +} + +public class PrimaryKeyQueueArgs : QueueArgs + where TPrimaryKey : notnull +{ + public TPrimaryKey Id { get; set; } = default!; +} + +public sealed class EnqueueOptions +{ + /// + /// The date and time the queue message expires. + /// + /// + /// Queue messages that are not processed until they expire + /// gets automatically deleted by the system. + /// + public DateTime Expire { get; set; } = DateTime.UtcNow.AddHours(48); +} + +public sealed class DequeueArgs : Dto +{ + [NonDefault] + public List Queues { get; set; } = default!; + + [Range(1, int.MaxValue)] + public int MaxCount { get; set; } + + public TimeSpan NextVisible { get; set; } +} diff --git a/Connected.Collections/SR.Designer.cs b/Connected.Collections/SR.Designer.cs new file mode 100644 index 0000000..0d04761 --- /dev/null +++ b/Connected.Collections/SR.Designer.cs @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Collections { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Collections.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Cannot create queued dispatcher. + /// + internal static string ErrCreateQueuedDispatcher { + get { + return ResourceManager.GetString("ErrCreateQueuedDispatcher", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot create service instance. + /// + internal static string ErrCreateService { + get { + return ResourceManager.GetString("ErrCreateService", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Iterator does not have current value. + /// + internal static string ErrIteratorCurrentNull { + get { + return ResourceManager.GetString("ErrIteratorCurrentNull", resourceCulture); + } + } + } +} diff --git a/Connected.Collections/SR.resx b/Connected.Collections/SR.resx new file mode 100644 index 0000000..f31d215 --- /dev/null +++ b/Connected.Collections/SR.resx @@ -0,0 +1,129 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cannot create queued dispatcher + + + Cannot create service instance + + + Iterator does not have current value + + \ No newline at end of file diff --git a/Connected.Configuration/Authentication/AuthenticationConfiguration.cs b/Connected.Configuration/Authentication/AuthenticationConfiguration.cs new file mode 100644 index 0000000..fbcb489 --- /dev/null +++ b/Connected.Configuration/Authentication/AuthenticationConfiguration.cs @@ -0,0 +1,11 @@ +namespace Connected.Configuration.Authentication; + +internal class AuthenticationConfiguration : IAuthenticationConfiguration +{ + public AuthenticationConfiguration() + { + JwToken = new JwTokenConfiguration(); + } + + public IJwTokenConfiguration JwToken { get; } +} diff --git a/Connected.Configuration/Authentication/IAuthenticationConfiguration.cs b/Connected.Configuration/Authentication/IAuthenticationConfiguration.cs new file mode 100644 index 0000000..5bc36c8 --- /dev/null +++ b/Connected.Configuration/Authentication/IAuthenticationConfiguration.cs @@ -0,0 +1,6 @@ +namespace Connected.Configuration.Authentication; + +public interface IAuthenticationConfiguration +{ + IJwTokenConfiguration JwToken { get; } +} diff --git a/Connected.Configuration/Authentication/IJwTokenConfiguration.cs b/Connected.Configuration/Authentication/IJwTokenConfiguration.cs new file mode 100644 index 0000000..c6bc4e1 --- /dev/null +++ b/Connected.Configuration/Authentication/IJwTokenConfiguration.cs @@ -0,0 +1,10 @@ +namespace Connected.Configuration.Authentication +{ + public interface IJwTokenConfiguration + { + string Issuer { get; } + string Audience { get; } + string Key { get; } + int Duration { get; } + } +} diff --git a/Connected.Configuration/Authentication/JwTokenConfiguration.cs b/Connected.Configuration/Authentication/JwTokenConfiguration.cs new file mode 100644 index 0000000..4129fd5 --- /dev/null +++ b/Connected.Configuration/Authentication/JwTokenConfiguration.cs @@ -0,0 +1,13 @@ +namespace Connected.Configuration.Authentication +{ + internal class JwTokenConfiguration : IJwTokenConfiguration + { + public string? Issuer { get; set; } + + public string? Audience { get; set; } + + public string? Key { get; set; } = "D78RF30487F4G0F8Z34F834F"; + + public int Duration { get; set; } = 30; + } +} diff --git a/Connected.Configuration/ConfigurationService.cs b/Connected.Configuration/ConfigurationService.cs new file mode 100644 index 0000000..6f7062c --- /dev/null +++ b/Connected.Configuration/ConfigurationService.cs @@ -0,0 +1,23 @@ +using Connected.Configuration.Authentication; +using Connected.Configuration.Endpoints; + +namespace Connected.Configuration; + +internal class ConfigurationService : IConfigurationService +{ + public ConfigurationService() + { + Endpoint = new EndpointConfiguration(); + Storage = new StorageConfiguration(); + } + + public IEndpointConfiguration Endpoint { get; } + + public IAuthenticationConfiguration Authentication => throw new NotImplementedException(); + public IStorageConfiguration Storage { get; } + + /* + * TODO: hardcoded + */ + public ProcessType Type => ProcessType.BackEnd; +} diff --git a/Connected.Configuration/ConfigurationStart.cs b/Connected.Configuration/ConfigurationStart.cs new file mode 100644 index 0000000..c3ba3ff --- /dev/null +++ b/Connected.Configuration/ConfigurationStart.cs @@ -0,0 +1,16 @@ +using Connected.Annotations; +using Connected.Configuration.Environment; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Configuration; + +internal class ConfigurationStart : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(IConfigurationService), typeof(ConfigurationService)); + services.AddSingleton(typeof(IEnvironmentService), typeof(EnvironmentService)); + } +} diff --git a/Connected.Configuration/ConfigurationUrls.cs b/Connected.Configuration/ConfigurationUrls.cs new file mode 100644 index 0000000..e158098 --- /dev/null +++ b/Connected.Configuration/ConfigurationUrls.cs @@ -0,0 +1,7 @@ +namespace Connected.Configuration +{ + public static class ConfigurationUrls + { + public const string Settings = "/configuration/settings"; + } +} diff --git a/Connected.Configuration/Connected.Configuration.csproj b/Connected.Configuration/Connected.Configuration.csproj new file mode 100644 index 0000000..fcaf092 --- /dev/null +++ b/Connected.Configuration/Connected.Configuration.csproj @@ -0,0 +1,13 @@ + + + + net7.0 + enable + enable + + + + + + + diff --git a/Connected.Configuration/DatabaseConfiguration.cs b/Connected.Configuration/DatabaseConfiguration.cs new file mode 100644 index 0000000..9935937 --- /dev/null +++ b/Connected.Configuration/DatabaseConfiguration.cs @@ -0,0 +1,23 @@ +using System.Collections.Immutable; + +namespace Connected.Configuration +{ + internal class DatabaseConfiguration : IDatabaseConfiguration + { + private List _shards; + + public DatabaseConfiguration() + { + /* + * TODO: read from config + */ + DefaultConnectionString = "server=PIT-ZBOOK\\sqlexpress; database=connected; trusted_connection=true;TrustServerCertificate=True;multiple active result sets=true"; + + _shards = new(); + } + + public string? DefaultConnectionString { get; init; } + + public ImmutableList Shards => _shards.ToImmutableList(); + } +} diff --git a/Connected.Configuration/Endpoints/EndpointConfiguration.cs b/Connected.Configuration/Endpoints/EndpointConfiguration.cs new file mode 100644 index 0000000..a042b84 --- /dev/null +++ b/Connected.Configuration/Endpoints/EndpointConfiguration.cs @@ -0,0 +1,7 @@ +namespace Connected.Configuration.Endpoints +{ + internal sealed class EndpointConfiguration : IEndpointConfiguration + { + public string? Address { get; set; } + } +} diff --git a/Connected.Configuration/Endpoints/IEndpointConfiguration.cs b/Connected.Configuration/Endpoints/IEndpointConfiguration.cs new file mode 100644 index 0000000..91f7578 --- /dev/null +++ b/Connected.Configuration/Endpoints/IEndpointConfiguration.cs @@ -0,0 +1,7 @@ +namespace Connected.Configuration.Endpoints +{ + public interface IEndpointConfiguration + { + string Address { get; } + } +} diff --git a/Connected.Configuration/Environment/EnvironmentService.cs b/Connected.Configuration/Environment/EnvironmentService.cs new file mode 100644 index 0000000..c5e2001 --- /dev/null +++ b/Connected.Configuration/Environment/EnvironmentService.cs @@ -0,0 +1,33 @@ +using System.Collections.Immutable; +using System.Reflection; +using Connected.Annotations; + +namespace Connected.Configuration.Environment; + +internal class EnvironmentService : IEnvironmentService +{ + private List? _assemblies; + public List All => _assemblies ??= QueryAssemblies(); + + public EnvironmentService() + { + Services = new EnvironmentServices(); + } + + public ImmutableList MicroServices => All.ToImmutableList(); + + public IEnvironmentServices Services { get; } + + private static List QueryAssemblies() + { + var result = new List(); + + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + if (assembly.GetCustomAttribute() is not null) + result.Add(assembly); + } + + return result; + } +} diff --git a/Connected.Configuration/Environment/EnvironmentServices.cs b/Connected.Configuration/Environment/EnvironmentServices.cs new file mode 100644 index 0000000..9164920 --- /dev/null +++ b/Connected.Configuration/Environment/EnvironmentServices.cs @@ -0,0 +1,17 @@ +using System.Collections.Immutable; + +namespace Connected.Configuration.Environment +{ + internal class EnvironmentServices : IEnvironmentServices + { + public ImmutableList Services => RegisteredServices.Services; + + public ImmutableList ServiceMethods => RegisteredServices.ServiceOperations; + + public ImmutableList Arguments => RegisteredServices.Arguments; + + public ImmutableList IoCEndpoints => RegisteredServices.Middleware; + + public ImmutableList EntityCache => RegisteredServices.EntityCache; + } +} diff --git a/Connected.Configuration/Environment/IEnvironmentService.cs b/Connected.Configuration/Environment/IEnvironmentService.cs new file mode 100644 index 0000000..e970175 --- /dev/null +++ b/Connected.Configuration/Environment/IEnvironmentService.cs @@ -0,0 +1,12 @@ +using System.Collections.Immutable; +using System.Reflection; + +namespace Connected.Configuration.Environment +{ + public interface IEnvironmentService + { + ImmutableList MicroServices { get; } + + IEnvironmentServices Services { get; } + } +} diff --git a/Connected.Configuration/Environment/IEnvironmentServices.cs b/Connected.Configuration/Environment/IEnvironmentServices.cs new file mode 100644 index 0000000..404e1cc --- /dev/null +++ b/Connected.Configuration/Environment/IEnvironmentServices.cs @@ -0,0 +1,13 @@ +using System.Collections.Immutable; + +namespace Connected.Configuration.Environment +{ + public interface IEnvironmentServices + { + ImmutableList Services { get; } + ImmutableList ServiceMethods { get; } + ImmutableList Arguments { get; } + ImmutableList IoCEndpoints { get; } + ImmutableList EntityCache { get; } + } +} diff --git a/Connected.Configuration/IConfigurationService.cs b/Connected.Configuration/IConfigurationService.cs new file mode 100644 index 0000000..47823a7 --- /dev/null +++ b/Connected.Configuration/IConfigurationService.cs @@ -0,0 +1,21 @@ +using Connected.Configuration.Authentication; +using Connected.Configuration.Endpoints; + +namespace Connected.Configuration; + +public enum ProcessType +{ + FrontEnd = 1, + BackEnd = 2, + Service = 3 +} + +public interface IConfigurationService +{ + IEndpointConfiguration Endpoint { get; } + + IAuthenticationConfiguration Authentication { get; } + IStorageConfiguration Storage { get; } + + ProcessType Type { get; } +} diff --git a/Connected.Configuration/IDatabaseConfiguration.cs b/Connected.Configuration/IDatabaseConfiguration.cs new file mode 100644 index 0000000..94ce5d4 --- /dev/null +++ b/Connected.Configuration/IDatabaseConfiguration.cs @@ -0,0 +1,11 @@ +using System.Collections.Immutable; + +namespace Connected.Configuration +{ + public interface IDatabaseConfiguration + { + string DefaultConnectionString { get; } + + ImmutableList Shards { get; } + } +} diff --git a/Connected.Configuration/IStorageConfiguration.cs b/Connected.Configuration/IStorageConfiguration.cs new file mode 100644 index 0000000..d5f1392 --- /dev/null +++ b/Connected.Configuration/IStorageConfiguration.cs @@ -0,0 +1,7 @@ +namespace Connected.Configuration +{ + public interface IStorageConfiguration + { + IDatabaseConfiguration Databases { get; } + } +} diff --git a/Connected.Configuration/RegisteredServices.cs b/Connected.Configuration/RegisteredServices.cs new file mode 100644 index 0000000..9c4b249 --- /dev/null +++ b/Connected.Configuration/RegisteredServices.cs @@ -0,0 +1,56 @@ +using System.Collections.Immutable; + +namespace Connected.Configuration +{ + public static class RegisteredServices + { + private static readonly List _services; + private static readonly List _serviceOperations; + private static readonly List _arguments; + private static readonly List _middleware; + private static readonly List _entityCache; + + static RegisteredServices() + { + _services = new List(); + _serviceOperations = new List(); + _arguments = new List(); + _middleware = new List(); + _entityCache = new List(); + } + + public static ImmutableList Services => _services.ToImmutableList(); + + public static ImmutableList ServiceOperations => _serviceOperations.ToImmutableList(); + + public static ImmutableList Arguments => _arguments.ToImmutableList(); + public static ImmutableList Middleware => _middleware.ToImmutableList(); + + public static ImmutableList EntityCache => _entityCache.ToImmutableList(); + + public static void AddApiService(Type type) + { + _services.Add(type); + } + + public static void AddApi(Type type) + { + _serviceOperations.Add(type); + } + + public static void AddArgument(Type type) + { + _arguments.Add(type); + } + + public static void AddMiddleware(Type type) + { + _middleware.Add(type); + } + + public static void AddEntityCache(Type type) + { + _entityCache.Add(type); + } + } +} diff --git a/Connected.Configuration/Settings/ISetting.cs b/Connected.Configuration/Settings/ISetting.cs new file mode 100644 index 0000000..088a50a --- /dev/null +++ b/Connected.Configuration/Settings/ISetting.cs @@ -0,0 +1,9 @@ +using Connected.Data; + +namespace Connected.Configuration.Settings; + +public interface ISetting : IPrimaryKey +{ + string Name { get; init; } + string Value { get; init; } +} diff --git a/Connected.Configuration/Settings/ISettingsService.cs b/Connected.Configuration/Settings/ISettingsService.cs new file mode 100644 index 0000000..f3a21f2 --- /dev/null +++ b/Connected.Configuration/Settings/ISettingsService.cs @@ -0,0 +1,25 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.ServiceModel; + +namespace Connected.Configuration.Settings; + +[Service] +[ServiceUrl(ConfigurationUrls.Settings)] +public interface ISettingsService +{ + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(NameArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get)] + Task> Query(); + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Patch)] + Task Update(SettingsArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Configuration/Settings/SettingsArgs.cs b/Connected.Configuration/Settings/SettingsArgs.cs new file mode 100644 index 0000000..bdccd72 --- /dev/null +++ b/Connected.Configuration/Settings/SettingsArgs.cs @@ -0,0 +1,14 @@ +using System.ComponentModel.DataAnnotations; +using Connected.ServiceModel; + +namespace Connected.Configuration.Settings; + +public class SettingsArgs : Dto +{ + [Required] + [MaxLength(128)] + public string? Name { get; set; } + + [MaxLength(1024)] + public string? Value { get; set; } +} diff --git a/Connected.Configuration/StorageConfiguration.cs b/Connected.Configuration/StorageConfiguration.cs new file mode 100644 index 0000000..67526e5 --- /dev/null +++ b/Connected.Configuration/StorageConfiguration.cs @@ -0,0 +1,12 @@ +namespace Connected.Configuration +{ + internal class StorageConfiguration : IStorageConfiguration + { + public StorageConfiguration() + { + Databases = new DatabaseConfiguration(); + } + + public IDatabaseConfiguration Databases { get; } + } +} diff --git a/Connected.Data/Annotations/ColumnAttribute.cs b/Connected.Data/Annotations/ColumnAttribute.cs new file mode 100644 index 0000000..1637c28 --- /dev/null +++ b/Connected.Data/Annotations/ColumnAttribute.cs @@ -0,0 +1,15 @@ +using Connected.Entities.Annotations; + +namespace Connected.Data.Annotations; + +[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = true)] +internal sealed class ColumnAttribute : MemberAttribute +{ + public string? Name { get; set; } + public string? TableId { get; set; } + public string? DbType { get; set; } + public bool IsComputed { get; set; } + public bool IsPrimaryKey { get; set; } + public bool IsGenerated { get; set; } + public bool IsReadOnly { get; set; } +} diff --git a/Connected.Data/Annotations/MemberExtensionAttribute.cs b/Connected.Data/Annotations/MemberExtensionAttribute.cs new file mode 100644 index 0000000..7746bc3 --- /dev/null +++ b/Connected.Data/Annotations/MemberExtensionAttribute.cs @@ -0,0 +1,8 @@ +namespace Connected.Data.Annotations +{ + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)] + internal class MemberExtensionAttribute : Attribute + { + public string TableId { get; set; } + } +} diff --git a/Connected.Data/Annotations/NestedEntityAttribute.cs b/Connected.Data/Annotations/NestedEntityAttribute.cs new file mode 100644 index 0000000..4e0f4a8 --- /dev/null +++ b/Connected.Data/Annotations/NestedEntityAttribute.cs @@ -0,0 +1,9 @@ +using Connected.Entities.Annotations; + +namespace Connected.Data.Annotations; + +[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = true)] +internal class NestedEntityAttribute : MemberAttribute +{ + public Type? RuntimeType { get; set; } +} diff --git a/Connected.Data/Connected.Data.csproj b/Connected.Data/Connected.Data.csproj new file mode 100644 index 0000000..e4874c3 --- /dev/null +++ b/Connected.Data/Connected.Data.csproj @@ -0,0 +1,40 @@ + + + + net7.0 + enable + enable + False + + + + + + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Data/DataExtensions.cs b/Connected.Data/DataExtensions.cs new file mode 100644 index 0000000..c3963c2 --- /dev/null +++ b/Connected.Data/DataExtensions.cs @@ -0,0 +1,100 @@ +using Connected.Data.Storage; +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using Connected.Interop; +using Connected.ServiceModel; +using System.Data; +using System.Reflection; + +namespace Connected.Data; + +public static class DataExtensions +{ + /// + /// Sets value to the on the + /// provided . + /// + /// The to set the value. + public static void UseIsolatedConnections(this IContext context) + { + if (context.GetService() is IConnectionProvider provider) + provider.Mode = StorageConnectionMode.Isolated; + } + + public static DbType ToDbType(PropertyInfo property) + { + var type = property.PropertyType; + + if (type.IsEnum) + type = Enum.GetUnderlyingType(type); + + if (type == typeof(char) || type == typeof(string)) + { + if (property.FindAttribute() != null) + return DbType.Binary; + + var str = property.FindAttribute(); + + if (str is null) + return DbType.String; + + return str.Kind switch + { + StringKind.NVarChar => DbType.String, + StringKind.VarChar => DbType.AnsiString, + StringKind.Char => DbType.AnsiStringFixedLength, + StringKind.NChar => DbType.StringFixedLength, + _ => DbType.String, + }; + } + else if (type == typeof(byte)) + return DbType.Byte; + else if (type == typeof(bool)) + return DbType.Boolean; + else if (type == typeof(DateTime) || type == typeof(DateTimeOffset)) + { + var att = property.FindAttribute(); + + if (att is null) + return DbType.DateTime2; + + return att.Kind switch + { + DateKind.Date => DbType.Date, + DateKind.DateTime => DbType.DateTime, + DateKind.DateTime2 => DbType.DateTime2, + DateKind.SmallDateTime => DbType.DateTime, + DateKind.Time => DbType.Time, + _ => DbType.DateTime2, + }; + } + else if (type == typeof(decimal)) + return DbType.Decimal; + else if (type == typeof(double)) + return DbType.Double; + else if (type == typeof(Guid)) + return DbType.Guid; + else if (type == typeof(short)) + return DbType.Int16; + else if (type == typeof(int)) + return DbType.Int32; + else if (type == typeof(long)) + return DbType.Int64; + else if (type == typeof(sbyte)) + return DbType.SByte; + else if (type == typeof(float)) + return DbType.Single; + else if (type == typeof(TimeSpan)) + return DbType.Time; + else if (type == typeof(ushort)) + return DbType.UInt16; + else if (type == typeof(uint)) + return DbType.UInt32; + else if (type == typeof(ulong)) + return DbType.UInt64; + else if (type == typeof(byte[])) + return DbType.Binary; + else + return DbType.Binary; + } +} diff --git a/Connected.Data/DataStartup.cs b/Connected.Data/DataStartup.cs new file mode 100644 index 0000000..dda0650 --- /dev/null +++ b/Connected.Data/DataStartup.cs @@ -0,0 +1,23 @@ +using Connected.Annotations; +using Connected.Data.DataProtection; +using Connected.Data.Schema; +using Connected.Data.Sharding; +using Connected.Data.Storage; +using Connected.Entities.Storage; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Data; + +internal sealed class DataStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddScoped(typeof(ISchemaService), typeof(SchemaService)); + services.AddScoped(typeof(IShardingService), typeof(ShardingService)); + services.AddScoped(typeof(IConnectionProvider), typeof(ConnectionProvider)); + services.AddScoped(typeof(IStorageProvider), typeof(StorageProvider)); + services.AddScoped(typeof(IEntityProtectionService), typeof(EntityProtectionService)); + } +} diff --git a/Connected.Data/EntityProtection/EntityProtectionArgs.cs b/Connected.Data/EntityProtection/EntityProtectionArgs.cs new file mode 100644 index 0000000..b47d9b3 --- /dev/null +++ b/Connected.Data/EntityProtection/EntityProtectionArgs.cs @@ -0,0 +1,14 @@ +using Connected.Entities; + +namespace Connected.Data.EntityProtection; +public sealed class EntityProtectionArgs : EventArgs +{ + public EntityProtectionArgs(TEntity entity, State state) + { + Entity = entity; + State = state; + } + + public TEntity Entity { get; } + public State State { get; } +} diff --git a/Connected.Data/EntityProtection/EntityProtectionService.cs b/Connected.Data/EntityProtection/EntityProtectionService.cs new file mode 100644 index 0000000..457b110 --- /dev/null +++ b/Connected.Data/EntityProtection/EntityProtectionService.cs @@ -0,0 +1,25 @@ +using Connected.Data.EntityProtection; +using Connected.Middleware; + +namespace Connected.Data.DataProtection; + +internal class EntityProtectionService : IEntityProtectionService +{ + public EntityProtectionService(IMiddlewareService middleware) + { + Middleware = middleware; + } + + public IMiddlewareService Middleware { get; } + + public async Task Invoke(EntityProtectionArgs args) + { + var middleware = await Middleware.Query>(); + + if (!middleware.Any()) + return; + + foreach (var m in middleware) + await m.Invoke(args); + } +} diff --git a/Connected.Data/EntityProtection/IEntityProtectionService.cs b/Connected.Data/EntityProtection/IEntityProtectionService.cs new file mode 100644 index 0000000..e04fdc0 --- /dev/null +++ b/Connected.Data/EntityProtection/IEntityProtectionService.cs @@ -0,0 +1,9 @@ +using Connected.Data.EntityProtection; + +namespace Connected.Data.DataProtection +{ + public interface IEntityProtectionService + { + Task Invoke(EntityProtectionArgs args); + } +} diff --git a/Connected.Data/EntityProtection/IEntityProtector.cs b/Connected.Data/EntityProtection/IEntityProtector.cs new file mode 100644 index 0000000..c712644 --- /dev/null +++ b/Connected.Data/EntityProtection/IEntityProtector.cs @@ -0,0 +1,8 @@ +using Connected.Data.EntityProtection; + +namespace Connected.Data.DataProtection; + +public interface IEntityProtector : IMiddleware +{ + Task Invoke(EntityProtectionArgs args); +} diff --git a/Connected.Data/EntityVersion.cs b/Connected.Data/EntityVersion.cs new file mode 100644 index 0000000..c46f442 --- /dev/null +++ b/Connected.Data/EntityVersion.cs @@ -0,0 +1,125 @@ +using Connected.Interop; + +namespace Connected.Data; + +internal class EntityVersion : IComparable, IEquatable, IComparable +{ + public static readonly EntityVersion? Zero = default; + + private readonly ulong Value; + + private EntityVersion(ulong value) + { + Value = value; + } + + public static EntityVersion? Parse(object value) + { + if (!TypeConversion.TryConvert(value, out string? v)) + return Zero; + + if (string.IsNullOrWhiteSpace(v)) + return Zero; + + return new EntityVersion(Convert.ToUInt64(v, 16)); + } + + public static implicit operator EntityVersion(ulong value) + { + return new EntityVersion(value); + } + + public static implicit operator EntityVersion(long value) + { + return new EntityVersion(unchecked((ulong)value)); + } + + public static explicit operator EntityVersion?(byte[] value) + { + if (value is null) + return null; + + return new EntityVersion((ulong)value[0] << 56 | (ulong)value[1] << 48 | (ulong)value[2] << 40 | (ulong)value[3] << 32 | (ulong)value[4] << 24 | (ulong)value[5] << 16 | (ulong)value[6] << 8 | value[7]); + } + + public static implicit operator byte[](EntityVersion timestamp) + { + var r = new byte[8]; + + r[0] = (byte)(timestamp.Value >> 56); + r[1] = (byte)(timestamp.Value >> 48); + r[2] = (byte)(timestamp.Value >> 40); + r[3] = (byte)(timestamp.Value >> 32); + r[4] = (byte)(timestamp.Value >> 24); + r[5] = (byte)(timestamp.Value >> 16); + r[6] = (byte)(timestamp.Value >> 8); + r[7] = (byte)timestamp.Value; + + return r; + } + + public override bool Equals(object? obj) + { + return obj is Version version && Equals(version); + } + + public override int GetHashCode() + { + return Value.GetHashCode(); + } + + public bool Equals(EntityVersion? other) + { + return other?.Value == Value; + } + + int IComparable.CompareTo(object? obj) + { + return obj is null ? 1 : CompareTo((EntityVersion)obj); + } + + public int CompareTo(EntityVersion? other) + { + return Value == other?.Value ? 0 : Value < other?.Value ? -1 : 1; + } + + public static bool operator ==(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.Equals(comparand2); + } + + public static bool operator !=(EntityVersion comparand1, EntityVersion comparand2) + { + return !comparand1.Equals(comparand2); + } + + public static bool operator >(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.CompareTo(comparand2) > 0; + } + + public static bool operator >=(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.CompareTo(comparand2) >= 0; + } + + public static bool operator <(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.CompareTo(comparand2) < 0; + } + + public static bool operator <=(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.CompareTo(comparand2) <= 0; + } + + public override string ToString() + { + return Value.ToString("x16"); + } + + public static EntityVersion Max(EntityVersion comparand1, EntityVersion comparand2) + { + return comparand1.Value < comparand2.Value ? comparand2 : comparand1; + } +} \ No newline at end of file diff --git a/Connected.Data/FieldMappings.cs b/Connected.Data/FieldMappings.cs new file mode 100644 index 0000000..05a6e81 --- /dev/null +++ b/Connected.Data/FieldMappings.cs @@ -0,0 +1,192 @@ +using Connected.Entities.Annotations; +using Connected.Interop; +using System.Data; +using System.Reflection; + +namespace Connected.Data; +/// +/// Performs mapping between fields and entity properties. +/// +/// The entity type to be used. +internal class FieldMappings +{ + private Dictionary _properties; + /// + /// Creates a new object. + /// + /// The providing entity records. + public FieldMappings(IDataReader reader) + { + Initialize(reader); + } + /// + /// Cached properties use when looping through the records. + /// + private Dictionary Properties => _properties; + /// + /// Initializes the mappings base on the provided + /// and + /// + /// The active reader containing records. + private void Initialize(IDataReader reader) + { + /* + * For primitive types there are no mappings since it's an scalar call. + */ + if (typeof(TEntity).IsTypePrimitive()) + return; + + _properties = new Dictionary(); + /* + * We are binding only properties, not fields. + */ + var properties = typeof(TEntity).GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + for (var i = 0; i < reader.FieldCount; i++) + { + if (FieldMappings.ResolveProperty(properties, reader.GetName(i)) is PropertyInfo property) + _properties.Add(i, property); + } + } + /// + /// Creates a new instance of the and binds + /// properties from the provided . + /// + /// The containing the record. + /// A new instance of the with bound values from the . + public TEntity? CreateInstance(IDataReader reader) + { + /* + * For primitive return values we'll use only the first field and return itž + * to the caller. + */ + if (typeof(TEntity).IsTypePrimitive()) + { + if (reader.FieldCount == 0) + return default; + + if (TypeConversion.TryConvert(reader[0], out TEntity? result)) + return result; + + return default; + } + /* + * It's an actual entity. First, create a new instance. Entities should have + * public parameterless constructor. + */ + if (typeof(TEntity?).CreateInstance() is not TEntity instance) + throw new NullReferenceException(typeof(TEntity).FullName); + + foreach (var property in Properties) + Bind(instance, property, reader); + + return instance; + } + /// + /// Resolves a correct property from the entity's properties based on a field name. + /// + /// The entity's properties. + /// The field name. + /// A if found, null otherwise. + private static PropertyInfo? ResolveProperty(PropertyInfo[] properties, string name) + { + /* + * There are two ways to map a property (evaluated in the following order): + * 1. from property name + * 2. from MemberAttribute + * + * We'll first perform case insensitive comparison because fields in the database are usually stored in a camelCase format. + */ + if (properties.FirstOrDefault(f => string.Equals(f.Name, name, StringComparison.OrdinalIgnoreCase)) is PropertyInfo property && property.CanWrite) + { + /* + * Property is found, examine if the persistence from the storage is supported. + */ + var att = property.FindAttribute(); + + if (att is null || att.Persistence.HasFlag(ColumnPersistence.Read)) + return property; + } + /* + * Property wasn't found, let's try to find it via MemberAttribute. + */ + foreach (var prop in properties) + { + /* + * It's case insensitive comparison again because we don't want bother user with exact matching. Since a database is probably case insensitive anyway + * there is no option to store columns with case sensitive names. + */ + if (prop.FindAttribute() is MemberAttribute nameAttribute && string.Compare(nameAttribute.Member, name, true) == 0 && prop.CanWrite) + return prop; + } + /* + * Property could't be found. The field will be ognored when reading data. + */ + return default; + } + /// + /// Binds the value to the entity's property. + /// + /// The instance of the entity. + /// The property on which value to be set. + /// The providing the value. + private static void Bind(object instance, KeyValuePair property, IDataReader reader) + { + var value = reader.GetValue(property.Key); + /* + * We won't bind null values. We'll leave the property as is. + */ + if (value is null || Convert.IsDBNull(value)) + return; + /* + * We have a few exceptions when binding values. + */ + if (property.Value.PropertyType == typeof(string) && value is byte[] bv) + { + /* + * If the property is string and the reader's value is byte array we are probably dealing + * with Consistency field. We'll first check if the property contains the attribute. If so, + * we'll convert byte array to eTag kind of value. If not we'll simply convert value to base64 + * string. + */ + if (property.Value.FindAttribute() is not null) + { + var versionValue = (EntityVersion?)bv; + + if (versionValue is null) + value = Convert.ToBase64String(bv); + else + value = versionValue.ToString(); + } + else + value = Convert.ToBase64String(bv); + } + else if (property.Value.PropertyType == typeof(DateTimeOffset)) + { + /* + * We don't perform any conversions on dates. All dates should be stored in a UTC + * format so we simply set the correct kind of date so it can be later correctly + * converted + */ + value = new DateTimeOffset(DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc)); + } + else if (property.Value.PropertyType == typeof(DateTime)) + { + /* + * Like DateTimeOffset, the same is true for DateTime values + */ + value = DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc); + } + else + { + /* + * For other values we just perform a conversion. + */ + value = TypeConversion.Convert(value, property.Value.PropertyType); + } + /* + * Now bind the property from the converted value. + */ + property.Value.SetValue(instance, value); + } +} diff --git a/Connected.Data/SR.Designer.cs b/Connected.Data/SR.Designer.cs new file mode 100644 index 0000000..f2a7a31 --- /dev/null +++ b/Connected.Data/SR.Designer.cs @@ -0,0 +1,72 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Data { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Connected.Data.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Data concurrency issue occured. + /// + internal static string ErrDataConcurrency { + get { + return ResourceManager.GetString("ErrDataConcurrency", resourceCulture); + } + } + } +} diff --git a/Connected.Data/SR.resx b/Connected.Data/SR.resx new file mode 100644 index 0000000..721e9d2 --- /dev/null +++ b/Connected.Data/SR.resx @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Data concurrency issue occured + + \ No newline at end of file diff --git a/Connected.Data/Schema/EntitySchema.cs b/Connected.Data/Schema/EntitySchema.cs new file mode 100644 index 0000000..9bb4d16 --- /dev/null +++ b/Connected.Data/Schema/EntitySchema.cs @@ -0,0 +1,38 @@ +namespace Connected.Data.Schema +{ + internal class EntitySchema : ISchema + { + public EntitySchema() + { + Columns = new(); + } + public List Columns { get; } + + public string? Schema { get; set; } + public string? Name { get; set; } + public string? Type { get; set; } + public bool Ignore { get; set; } + public bool Equals(ISchema? other) + { + if (other is null) + return false; + + if (!string.Equals(Name, other.Name, StringComparison.Ordinal)) + return false; + + if (!string.Equals(Schema, other.Schema, StringComparison.Ordinal)) + return false; + + if (Columns.Count != other.Columns.Count) + return false; + + for (var i = 0; i < Columns.Count; i++) + { + if (Columns[i] is not IEquatable left || !left.Equals(other.Columns[i])) + return false; + } + + return true; + } + } +} diff --git a/Connected.Data/Schema/ExistingColumn.cs b/Connected.Data/Schema/ExistingColumn.cs new file mode 100644 index 0000000..9c4efd4 --- /dev/null +++ b/Connected.Data/Schema/ExistingColumn.cs @@ -0,0 +1,62 @@ +using System.Collections.Immutable; +using System.Data; +using Connected.Data.Schema.Sql; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema; + +internal class ExistingColumn : ISchemaColumn, IExistingSchemaColumn +{ + public ExistingColumn(ISchema schema) + { + Schema = schema; + } + + private ISchema Schema { get; } + + public string Name { get; set; } + + public DbType DataType { get; set; } + + public bool IsIdentity { get; set; } + public bool IsVersion { get; set; } + + public bool IsUnique { get; set; } + + public bool IsIndex { get; set; } + + public bool IsPrimaryKey { get; set; } + + public string DefaultValue { get; set; } + + public int MaxLength { get; set; } + + public bool IsNullable { get; set; } + + public string DependencyType { get; set; } + + public string DependencyProperty { get; set; } + + public string Index { get; set; } + public int Precision { get; set; } + public int Scale { get; set; } + + public DateKind DateKind { get; set; } = DateKind.DateTime; + public BinaryKind BinaryKind { get; set; } = BinaryKind.VarBinary; + + public int DatePrecision { get; set; } + + public ImmutableArray QueryIndexColumns(string column) + { + if (Schema is not ExistingSchema existing) + return ImmutableArray.Empty; + + foreach (var index in existing.Indexes) + { + if (index.Columns.Contains(column, StringComparer.OrdinalIgnoreCase)) + return index.Columns.ToImmutableArray(); + } + + return ImmutableArray.Empty; + } +} diff --git a/Connected.Data/Schema/IDatabase.cs b/Connected.Data/Schema/IDatabase.cs new file mode 100644 index 0000000..3dcb0e4 --- /dev/null +++ b/Connected.Data/Schema/IDatabase.cs @@ -0,0 +1,7 @@ +namespace Connected.Data.Schema +{ + internal interface IDatabase + { + List Tables { get; } + } +} diff --git a/Connected.Data/Schema/IExistingSchemaColumn.cs b/Connected.Data/Schema/IExistingSchemaColumn.cs new file mode 100644 index 0000000..6500133 --- /dev/null +++ b/Connected.Data/Schema/IExistingSchemaColumn.cs @@ -0,0 +1,9 @@ +using System.Collections.Immutable; + +namespace Connected.Data.Schema +{ + internal interface IExistingSchemaColumn + { + ImmutableArray QueryIndexColumns(string column); + } +} diff --git a/Connected.Data/Schema/IReferentialConstraint.cs b/Connected.Data/Schema/IReferentialConstraint.cs new file mode 100644 index 0000000..65ce849 --- /dev/null +++ b/Connected.Data/Schema/IReferentialConstraint.cs @@ -0,0 +1,12 @@ +namespace Connected.Data.Schema +{ + internal interface IReferentialConstraint + { + string Name { get; } + string ReferenceSchema { get; } + string ReferenceName { get; } + string MatchOption { get; } + string UpdateRule { get; } + string DeleteRule { get; } + } +} diff --git a/Connected.Data/Schema/ISchema.cs b/Connected.Data/Schema/ISchema.cs new file mode 100644 index 0000000..d3cdc61 --- /dev/null +++ b/Connected.Data/Schema/ISchema.cs @@ -0,0 +1,12 @@ +namespace Connected.Data.Schema +{ + public interface ISchema : IEquatable + { + List Columns { get; } + + string? Schema { get; } + string? Name { get; } + string? Type { get; } + bool Ignore { get; } + } +} diff --git a/Connected.Data/Schema/ISchemaColumn.cs b/Connected.Data/Schema/ISchemaColumn.cs new file mode 100644 index 0000000..d24fcfc --- /dev/null +++ b/Connected.Data/Schema/ISchemaColumn.cs @@ -0,0 +1,24 @@ +using System.Data; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema; + +public interface ISchemaColumn +{ + string? Name { get; } + DbType DataType { get; } + bool IsIdentity { get; } + bool IsUnique { get; } + bool IsIndex { get; } + bool IsPrimaryKey { get; } + bool IsVersion { get; } + string? DefaultValue { get; } + int MaxLength { get; } + bool IsNullable { get; } + string? Index { get; } + int Scale { get; } + int Precision { get; } + DateKind DateKind { get; } + BinaryKind BinaryKind { get; } + int DatePrecision { get; } +} diff --git a/Connected.Data/Schema/ISchemaMiddleware.cs b/Connected.Data/Schema/ISchemaMiddleware.cs new file mode 100644 index 0000000..96f4b66 --- /dev/null +++ b/Connected.Data/Schema/ISchemaMiddleware.cs @@ -0,0 +1,9 @@ +namespace Connected.Data.Schema; + +public interface ISchemaMiddleware : IMiddleware +{ + Task IsEntitySupported(Type entity); + Task Synchronize(Type entity, ISchema schema); + Type ConnectionType { get; } + string DefaultConnectionString { get; } +} diff --git a/Connected.Data/Schema/ISchemaService.cs b/Connected.Data/Schema/ISchemaService.cs new file mode 100644 index 0000000..ca01dec --- /dev/null +++ b/Connected.Data/Schema/ISchemaService.cs @@ -0,0 +1,7 @@ +namespace Connected.Data.Schema +{ + public interface ISchemaService + { + Task Synchronize(List? entities); + } +} diff --git a/Connected.Data/Schema/ISchemaSynchronizationContext.cs b/Connected.Data/Schema/ISchemaSynchronizationContext.cs new file mode 100644 index 0000000..e32d9bf --- /dev/null +++ b/Connected.Data/Schema/ISchemaSynchronizationContext.cs @@ -0,0 +1,7 @@ +namespace Connected.Data.Schema; + +public interface ISchemaSynchronizationContext +{ + Type ConnectionType { get; } + string ConnectionString { get; } +} diff --git a/Connected.Data/Schema/ITable.cs b/Connected.Data/Schema/ITable.cs new file mode 100644 index 0000000..e20f952 --- /dev/null +++ b/Connected.Data/Schema/ITable.cs @@ -0,0 +1,8 @@ +namespace Connected.Data.Schema +{ + internal interface ITable : ISchema + { + List Columns { get; } + List Indexes { get; } + } +} diff --git a/Connected.Data/Schema/ITableColumn.cs b/Connected.Data/Schema/ITableColumn.cs new file mode 100644 index 0000000..85c5e82 --- /dev/null +++ b/Connected.Data/Schema/ITableColumn.cs @@ -0,0 +1,22 @@ +namespace Connected.Data.Schema +{ + internal interface ITableColumn + { + string Name { get; } + string DataType { get; } + bool Identity { get; } + bool IsNullable { get; } + string DefaultValue { get; } + int Ordinal { get; } + int CharacterMaximumLength { get; } + int CharacterOctetLength { get; } + int NumericPrecision { get; } + int NumericPrecisionRadix { get; } + int NumericScale { get; } + int DateTimePrecision { get; } + string CharacterSetName { get; } + + IReferentialConstraint Reference { get; } + List Constraints { get; } + } +} diff --git a/Connected.Data/Schema/ITableConstraint.cs b/Connected.Data/Schema/ITableConstraint.cs new file mode 100644 index 0000000..8bdf370 --- /dev/null +++ b/Connected.Data/Schema/ITableConstraint.cs @@ -0,0 +1,6 @@ +namespace Connected.Data.Schema +{ + internal interface ITableConstraint : ISchema + { + } +} diff --git a/Connected.Data/Schema/ITableIndex.cs b/Connected.Data/Schema/ITableIndex.cs new file mode 100644 index 0000000..2301905 --- /dev/null +++ b/Connected.Data/Schema/ITableIndex.cs @@ -0,0 +1,8 @@ +namespace Connected.Data.Schema +{ + internal interface ITableIndex + { + string Name { get; } + List Columns { get; } + } +} diff --git a/Connected.Data/Schema/SchemaColumn.cs b/Connected.Data/Schema/SchemaColumn.cs new file mode 100644 index 0000000..77f922d --- /dev/null +++ b/Connected.Data/Schema/SchemaColumn.cs @@ -0,0 +1,120 @@ +using System.Data; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema; + +internal class SchemaColumn : IEquatable, ISchemaColumn +{ + public SchemaColumn(ISchema schema) + { + Schema = schema; + } + private ISchema Schema { get; } + + public string? Name { get; set; } + public DbType DataType { get; set; } + public bool IsIdentity { get; set; } + public bool IsUnique { get; set; } + public bool IsVersion { get; set; } + public bool IsIndex { get; set; } + public bool IsPrimaryKey { get; set; } + public string? DefaultValue { get; set; } + public int MaxLength { get; set; } + public bool IsNullable { get; set; } + public string? Index { get; set; } + public int Precision { get; set; } + public int Scale { get; set; } + public DateKind DateKind { get; set; } = DateKind.DateTime; + public BinaryKind BinaryKind { get; set; } = BinaryKind.VarBinary; + public int DatePrecision { get; set; } + + public int Ordinal { get; set; } + + public bool Equals(ISchemaColumn? other) + { + if (other is null) + return false; + + if (!string.Equals(Name, other.Name, StringComparison.OrdinalIgnoreCase)) + return false; + + if (DataType != other.DataType) + return false; + + if (IsIdentity != other.IsIdentity) + return false; + + if (IsUnique != other.IsUnique) + return false; + + if (IsIndex != other.IsIndex) + return false; + + if (IsVersion != other.IsVersion) + return false; + + if (IsPrimaryKey != other.IsPrimaryKey) + return false; + + if (Precision != other.Precision) + return false; + + if (Scale != other.Scale) + return false; + + if (!string.Equals(DefaultValue, other.DefaultValue, StringComparison.Ordinal)) + return false; + + if (MaxLength != other.MaxLength) + return false; + + if (IsNullable != other.IsNullable) + return false; + + if (DateKind != other.DateKind) + return false; + + if (DatePrecision != other.DatePrecision) + return false; + + if (BinaryKind != other.BinaryKind) + return false; + + if (other is IExistingSchemaColumn existing) + { + var existingColumns = existing.QueryIndexColumns(Name); + + if (existingColumns.Any() || IsIndex) + { + var columns = new List(); + + if (!string.IsNullOrWhiteSpace(Index)) + { + foreach (var column in Schema.Columns) + { + if (string.Equals(column.Index, Index, StringComparison.OrdinalIgnoreCase)) + columns.Add(column.Name); + } + } + else + columns.Add(Name); + + if (existingColumns.Length != columns.Count) + return false; + + existingColumns = existingColumns.Sort(); + columns.Sort(); + + for (var i = 0; i < existingColumns.Length; i++) + { + if (!string.Equals(existingColumns[i], columns[i], StringComparison.OrdinalIgnoreCase)) + return false; + } + } + } + else + return string.Equals(Index, other.Index, StringComparison.Ordinal); + + return true; + } +} diff --git a/Connected.Data/Schema/SchemaService.cs b/Connected.Data/Schema/SchemaService.cs new file mode 100644 index 0000000..6423c8e --- /dev/null +++ b/Connected.Data/Schema/SchemaService.cs @@ -0,0 +1,249 @@ +using System.Collections.Immutable; +using System.Data; +using System.Globalization; +using System.Reflection; +using Connected.Annotations; +using Connected.Entities.Annotations; +using Connected.Interop; +using Connected.Middleware; +using Connected.Threading; +using Microsoft.Extensions.Logging; + +namespace Connected.Data.Schema; + +internal class SchemaService : ISchemaService +{ + public SchemaService(IMiddlewareService middleware, ILogger logger) + { + Middleware = new AsyncLazy>(middleware.Query()); + Logger = logger; + } + + private ILogger Logger { get; } + private AsyncLazy> Middleware { get; } + + public async Task Synchronize(List? entities) + { + if (Middleware is null) + { + Logger.LogWarning("No ISchemaMiddleware is registered."); + return; + } + + if (entities is null || !entities.Any()) + return; + + foreach (var entity in entities) + { + if (!IsPersistent(entity)) + continue; + + Logger.LogTrace("Synchronizing entity '{entity}", entity.Name); + + var synchronized = false; + var schema = CreateSchema(entity); + + if (schema.Ignore) + continue; + + foreach (var middleware in await Middleware.Value) + { + /* + * We are looking for the first middleware which returns true, + * which means it supports entity synchronization. + */ + if (!await middleware.IsEntitySupported(entity)) + continue; + /* + * Note that sharding synchronization will be handled by the middleware. + */ + await middleware.Synchronize(entity, schema); + + synchronized = true; + } + /* + * We should notify the environment that entity is no synchronized. + * Maybe we should throw the exception here because unsynchronized + * entities could cause system instabillity. + */ + if (!synchronized) + Logger.LogWarning("No middleware synchronized the entity ({entity}).", entity.Name); + } + } + /// + /// Determines if the entity supports persistence. Virtual entities does not support persistence which + /// means they don't have physical storage. + /// + /// The type of the entity to check for persistence. + /// true if the entity supports persistence, false otherwise. + private static bool IsPersistent(Type entityType) + { + var persistence = entityType.GetCustomAttribute(); + + return persistence is null || !persistence.Persistence.HasFlag(ColumnPersistence.Write); + } + + private static ISchema CreateSchema(Type type) + { + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic); + var att = ResolveSchemaAttribute(type); + + var result = new EntitySchema + { + Name = att.Name, + Schema = att.Schema + }; + + var columns = new List(); + + foreach (var property in properties) + { + if (!property.CanWrite) + continue; + + if (property.FindAttribute() is PersistenceAttribute pa && pa.IsVirtual) + continue; + + var column = new SchemaColumn(result) + { + Name = ResolveColumnName(property), + DataType = DataExtensions.ToDbType(property) + }; + + var pk = property.FindAttribute(); + + if (pk != null) + { + column.IsPrimaryKey = true; + column.IsIdentity = pk.Identity; + column.IsUnique = true; + column.IsIndex = true; + } + + var idx = property.FindAttribute(); + + if (idx != null) + { + column.IsIndex = true; + column.IsUnique = idx.Unique; + column.Index = idx.Name; + } + + var ordinal = property.FindAttribute(); + + if (ordinal != null) + column.Ordinal = ordinal.Value; + + if (column.DataType == DbType.Decimal + || column.DataType == DbType.VarNumeric) + { + var numeric = property.FindAttribute(); + + if (numeric != null) + { + column.Precision = numeric.Percision; + column.Scale = numeric.Scale; + } + else + { + column.Precision = 20; + column.Scale = 5; + } + + } + else if (column.DataType == DbType.Date + || column.DataType == DbType.DateTime + || column.DataType == DbType.DateTime2 + || column.DataType == DbType.DateTimeOffset + || column.DataType == DbType.Time) + { + var dateAtt = property.FindAttribute(); + + if (dateAtt is not null) + { + column.DateKind = dateAtt.Kind; + column.DatePrecision = dateAtt.Precision; + } + else + { + column.DateKind = DateKind.DateTime2; + column.DatePrecision = 7; + } + } + else if (column.DataType == DbType.Binary) + { + var bin = property.FindAttribute(); + + if (bin is not null) + column.BinaryKind = bin.Kind; + } + else if (column.DataType == DbType.String + || column.DataType == DbType.AnsiString + || column.DataType == DbType.AnsiStringFixedLength + || column.DataType == DbType.StringFixedLength) + { + column.MaxLength = 50; + } + + ParseDefaultValue(column, property); + + if (property.FindAttribute() is not null) + column.IsVersion = true; + else + { + var maxLength = property.FindAttribute(); + + if (maxLength is not null) + column.MaxLength = maxLength.Value; + } + + var nullable = property.FindAttribute(); + + if (nullable is null) + column.IsNullable = property.PropertyType.IsNullableType(); + else + column.IsNullable = nullable.IsNullable; + + columns.Add(column); + } + + if (columns.Any()) + result.Columns.AddRange(columns.OrderBy(f => f.Ordinal).ThenBy(f => f.Name)); + + return result; + } + + private static SchemaAttribute ResolveSchemaAttribute(Type type) + { + var att = type.GetCustomAttribute() ?? new TableAttribute(); + + if (string.IsNullOrWhiteSpace(att.Name)) + att.Name = type.Name.ToCamelCase(); + + if (string.IsNullOrEmpty(att.Schema)) + att.Schema = SchemaAttribute.DefaultSchema; + + return att; + } + + private static string ResolveColumnName(PropertyInfo property) + { + if (property.FindAttribute() is not MemberAttribute mapping || string.IsNullOrWhiteSpace(mapping.Member)) + return property.Name.ToCamelCase(); + + return mapping.Member; + } + + private static void ParseDefaultValue(SchemaColumn column, PropertyInfo property) + { + if (property.FindAttribute() is not DefaultAttribute def) + return; + + var value = def.Value; + + if (def.Value is not null && def.Value.GetType().IsEnum) + value = TypeConversion.Convert(def.Value, def.Value.GetType().GetEnumUnderlyingType()); + + column.DefaultValue = TypeConversion.Convert(value, CultureInfo.InvariantCulture); + } +} diff --git a/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs b/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs new file mode 100644 index 0000000..7aa3d8b --- /dev/null +++ b/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs @@ -0,0 +1,11 @@ +using Connected.Entities; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema.Sql; + +[Persistence(Persistence = ColumnPersistence.InMemory)] +internal sealed class AdHocSchemaEntity : IEntity +{ + public State State { get; init; } + public bool Result { get; init; } +} diff --git a/Connected.Data/Schema/Sql/ColumnAdd.cs b/Connected.Data/Schema/Sql/ColumnAdd.cs new file mode 100644 index 0000000..8871a2a --- /dev/null +++ b/Connected.Data/Schema/Sql/ColumnAdd.cs @@ -0,0 +1,35 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class ColumnAdd : ColumnTransaction + { + public ColumnAdd(ISchemaColumn column) : base(column) + { + } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + + if (Column.IsPrimaryKey) + await new PrimaryKeyAdd(Column).Execute(Context); + + if (!string.IsNullOrWhiteSpace(Column.DefaultValue)) + await new DefaultAdd(Column, Context.Schema.Name).Execute(Context); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + text.AppendLine($"ADD COLUMN {CreateColumnCommandText(Column)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ColumnAlter.cs b/Connected.Data/Schema/Sql/ColumnAlter.cs new file mode 100644 index 0000000..c3b5a9a --- /dev/null +++ b/Connected.Data/Schema/Sql/ColumnAlter.cs @@ -0,0 +1,60 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class ColumnAlter : ColumnTransaction + { + public ColumnAlter(ISchemaColumn column, ExistingSchema existing, ISchemaColumn existingColumn) : base(column) + { + Existing = existing; + ExistingColumn = existingColumn; + } + + private ExistingSchema Existing { get; } + private ISchemaColumn ExistingColumn { get; } + + protected override async Task OnExecute() + { + if (ExistingColumn.Equals(Column)) + return; + + if (!string.IsNullOrWhiteSpace(ExistingColumn.DefaultValue)) + { + var existingDefault = SchemaExtensions.ParseDefaultValue(ExistingColumn.DefaultValue); + var def = SchemaExtensions.ParseDefaultValue(Column.DefaultValue); + + if (!string.Equals(existingDefault, def, StringComparison.Ordinal)) + await new DefaultDrop(Column).Execute(Context); + } + if (Column.DataType != ExistingColumn.DataType + || Column.IsNullable != ExistingColumn.IsNullable + || Column.MaxLength != ExistingColumn.MaxLength + || Column.IsVersion != ExistingColumn.IsVersion) + await Context.Execute(CommandText); + + var ed = SchemaExtensions.ParseDefaultValue(ExistingColumn.DefaultValue); + var nd = SchemaExtensions.ParseDefaultValue(Column.DefaultValue); + + if (!string.Equals(ed, nd, StringComparison.Ordinal) && nd is not null) + await new DefaultAdd(Column, Context.Schema.Name).Execute(Context); + + if (!ExistingColumn.IsPrimaryKey && Column.IsPrimaryKey) + await new PrimaryKeyAdd(Column).Execute(Context); + else if (ExistingColumn.IsPrimaryKey && !Column.IsPrimaryKey) + await new PrimaryKeyRemove(Existing, ExistingColumn).Execute(Context); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + text.AppendLine($"ALTER COLUMN {CreateColumnCommandText(Column)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ColumnDrop.cs b/Connected.Data/Schema/Sql/ColumnDrop.cs new file mode 100644 index 0000000..716aba1 --- /dev/null +++ b/Connected.Data/Schema/Sql/ColumnDrop.cs @@ -0,0 +1,39 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class ColumnDrop : ColumnTransaction + { + public ColumnDrop(ISchemaColumn column, ExistingSchema existing) : base(column) + { + Existing = existing; + } + + private ExistingSchema Existing { get; } + + protected override async Task OnExecute() + { + if (!string.IsNullOrWhiteSpace(Column.DefaultValue)) + await new DefaultDrop(Column).Execute(Context); + + var indexes = Existing.ResolveIndexes(Column.Name); + + foreach (var index in indexes) + await new IndexDrop(index).Execute(Context); + + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)} DROP COLUMN {Column.Name};"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ColumnTransaction.cs b/Connected.Data/Schema/Sql/ColumnTransaction.cs new file mode 100644 index 0000000..87e1c59 --- /dev/null +++ b/Connected.Data/Schema/Sql/ColumnTransaction.cs @@ -0,0 +1,12 @@ +namespace Connected.Data.Schema.Sql +{ + internal abstract class ColumnTransaction : TableTransaction + { + public ColumnTransaction(ISchemaColumn column) + { + Column = column; + } + + protected ISchemaColumn Column { get; } + } +} diff --git a/Connected.Data/Schema/Sql/Columns.cs b/Connected.Data/Schema/Sql/Columns.cs new file mode 100644 index 0000000..425ce52 --- /dev/null +++ b/Connected.Data/Schema/Sql/Columns.cs @@ -0,0 +1,83 @@ +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using System.Data; +using System.Text; + +namespace Connected.Data.Schema.Sql; + +internal class Columns : SynchronizationQuery> +{ + public Columns(ExistingSchema existing) + { + Existing = existing; + } + + private ExistingSchema Existing { get; } + + protected override async Task> OnExecute() + { + var result = new List(); + var rdr = await Context.OpenReader(new StorageOperation { CommandText = CommandText }); + + while (rdr.Read()) + { + var column = new ExistingColumn(Existing) + { + IsNullable = !string.Equals(rdr.GetValue("IS_NULLABLE", string.Empty), "NO", StringComparison.OrdinalIgnoreCase), + DataType = SchemaExtensions.ToDbType(rdr.GetValue("DATA_TYPE", string.Empty)), + MaxLength = rdr.GetValue("CHARACTER_MAXIMUM_LENGTH", 0), + Name = rdr.GetValue("COLUMN_NAME", string.Empty), + }; + + if (column.DataType == DbType.Decimal || column.DataType == DbType.VarNumeric) + { + column.Precision = rdr.GetValue("NUMERIC_PRECISION", 0); + column.Scale = rdr.GetValue("NUMERIC_SCALE", 0); + } + + if (column.DataType == DbType.DateTime2 + || column.DataType == DbType.Time + || column.DataType == DbType.DateTimeOffset) + column.DatePrecision = rdr.GetValue("DATETIME_PRECISION", 0); + + if (column.DataType == DbType.Date) + column.DateKind = DateKind.Date; + else if (column.DataType == DbType.DateTime) + { + if (string.Compare(rdr.GetValue("DATA_TYPE", string.Empty), "smalldatetime", true) == 0) + column.DateKind = DateKind.SmallDateTime; + } + else if (column.DataType == DbType.DateTime2) + column.DateKind = DateKind.DateTime2; + else if (column.DataType == DbType.Time) + column.DateKind = DateKind.Time; + else if (column.DataType == DbType.Binary) + { + if (string.Compare(rdr.GetValue("DATA_TYPE", string.Empty), "varbinary", true) == 0) + column.BinaryKind = BinaryKind.VarBinary; + else if (string.Compare(rdr.GetValue("DATA_TYPE", string.Empty), "binary", true) == 0) + column.BinaryKind = BinaryKind.Binary; + } + + column.IsVersion = string.Equals(rdr.GetValue("DATA_TYPE", string.Empty), "timestamp", StringComparison.OrdinalIgnoreCase); + + result.Add(column); + } + + rdr.Close(); + + return result; + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '{Context.Schema.SchemaName()}' AND TABLE_NAME = '{Context.Schema.Name}'"); + + return text.ToString(); + } + } +} diff --git a/Connected.Data/Schema/Sql/ConstraintDrop.cs b/Connected.Data/Schema/Sql/ConstraintDrop.cs new file mode 100644 index 0000000..15b80d3 --- /dev/null +++ b/Connected.Data/Schema/Sql/ConstraintDrop.cs @@ -0,0 +1,30 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class ConstraintDrop : TableTransaction + { + public ConstraintDrop(ObjectIndex index) + { + Index = index; + } + + private ObjectIndex Index { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.Schema, Context.Schema.Name)} DROP CONSTRAINT {Index.Name};"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/DataCopy.cs b/Connected.Data/Schema/Sql/DataCopy.cs new file mode 100644 index 0000000..334f3b1 --- /dev/null +++ b/Connected.Data/Schema/Sql/DataCopy.cs @@ -0,0 +1,90 @@ +using System.Data; +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class DataCopy : TableTransaction + { + public DataCopy(ExistingSchema existing, string temporaryName) + { + Existing = existing; + TemporaryName = temporaryName; + } + + private ExistingSchema Existing { get; } + public string TemporaryName { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + var columnSet = new StringBuilder(); + var sourceSet = new StringBuilder(); + var comma = string.Empty; + + foreach (var column in Context.Schema.Columns) + { + if (column.IsVersion) + continue; + + var existing = Existing.Columns.FirstOrDefault(f => string.Equals(column.Name, f.Name, StringComparison.OrdinalIgnoreCase)); + + if (existing is null) + continue; + + columnSet.Append($"{comma}{Escape(column.Name)}"); + + if (NeedsConversion(column) && (existing.DataType != column.DataType || existing.Precision != column.Precision || existing.Scale != column.Scale)) + sourceSet.Append($"{comma}CONVERT({ConversionString(column)},{Escape(column.Name)})"); + else + sourceSet.Append($"{comma}{Escape(column.Name)}"); + + comma = ","; + } + + text.AppendLine($"IF EXISTS (SELECT * FROM {Escape(Existing.SchemaName(), Existing.Name)})"); + text.AppendLine($"EXEC ('INSERT INTO {Escape(Context.Schema.SchemaName(), TemporaryName)} ({columnSet.ToString()})"); + text.AppendLine($"SELECT {sourceSet.ToString()} FROM {Escape(Existing.SchemaName(), Existing.Name)}')"); + + return text.ToString(); + } + } + + private static string ConversionString(ISchemaColumn column) + { + return column.DataType switch + { + DbType.Byte => "tinyint", + DbType.Currency => "money", + DbType.Decimal => $"decimal({column.Precision}, {column.Scale})", + DbType.Double => "real", + DbType.Int16 => "smallint", + DbType.Int32 => "int", + DbType.Int64 => "bigint", + DbType.SByte => "smallint", + DbType.Single => "float", + DbType.UInt16 => "int", + DbType.UInt32 => "bigint", + DbType.UInt64 => "float", + DbType.VarNumeric => $"numeric({column.Precision}, {column.Scale})", + _ => throw new NotSupportedException(), + }; + } + + private static bool NeedsConversion(ISchemaColumn column) + { + return column.DataType switch + { + DbType.Byte or DbType.Currency or DbType.Decimal or DbType.Double or DbType.Int16 or DbType.Int32 or DbType.Int64 or DbType.SByte + or DbType.Single or DbType.UInt16 or DbType.UInt32 or DbType.UInt64 or DbType.VarNumeric => true, + _ => false, + }; + } + } +} diff --git a/Connected.Data/Schema/Sql/DefaultAdd.cs b/Connected.Data/Schema/Sql/DefaultAdd.cs new file mode 100644 index 0000000..4104422 --- /dev/null +++ b/Connected.Data/Schema/Sql/DefaultAdd.cs @@ -0,0 +1,34 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class DefaultAdd : ColumnTransaction + { + public DefaultAdd(ISchemaColumn column, string tableName) : base(column) + { + TableName = tableName; + } + + private string TableName { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + var defValue = SchemaExtensions.ParseDefaultValue(Column.DefaultValue); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), TableName)}"); + text.AppendLine($"ADD CONSTRAINT {Context.GenerateConstraintName(Context.Schema.SchemaName(), TableName, ConstraintNameType.Default)} DEFAULT {defValue} FOR {Column.Name}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/DefaultDrop.cs b/Connected.Data/Schema/Sql/DefaultDrop.cs new file mode 100644 index 0000000..746ee6d --- /dev/null +++ b/Connected.Data/Schema/Sql/DefaultDrop.cs @@ -0,0 +1,52 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class DefaultDrop : ColumnTransaction + { + public DefaultDrop(ISchemaColumn column) : base(column) + { + } + + protected override async Task OnExecute() + { + if (string.IsNullOrWhiteSpace(DefaultName)) + return; + + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + text.AppendLine($"DROP CONSTRAINT {DefaultName};"); + + return text.ToString(); + } + } + + private string? DefaultName + { + get + { + if (Context.ExistingSchema is null) + return null; + + foreach (var constraint in Context.ExistingSchema.Descriptor.Constraints) + { + if (constraint.ConstraintType == ConstraintType.Default) + { + if (constraint.Columns.Count == 1 && string.Equals(constraint.Columns[0], Column.Name, StringComparison.OrdinalIgnoreCase)) + return constraint.Name; + } + } + + return null; + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ExistingSchema.cs b/Connected.Data/Schema/Sql/ExistingSchema.cs new file mode 100644 index 0000000..cdf5127 --- /dev/null +++ b/Connected.Data/Schema/Sql/ExistingSchema.cs @@ -0,0 +1,109 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ExistingSchema : ISchema + { + public ExistingSchema() + { + Columns = new(); + } + + public List Columns { get; } + + public string? Schema { get; set; } + + public string? Name { get; set; } + + public string? Type { get; set; } + + public bool Ignore { get; set; } + + public ObjectDescriptor? Descriptor { get; private set; } + + public async Task Load(SchemaExecutionContext context) + { + Name = context.Schema.Name; + Type = context.Schema.Type; + Schema = context.Schema.SchemaName(); + + Columns.AddRange(await new Columns(this).Execute(context)); + Descriptor = await new SpHelp().Execute(context); + + if (Columns.FirstOrDefault(f => string.Equals(f.Name, Descriptor.Identity.Identity, StringComparison.OrdinalIgnoreCase)) is ExistingColumn c) + c.IsIdentity = true; + + foreach (var index in Descriptor.Indexes) + { + foreach (var column in index.Columns) + { + if (Columns.FirstOrDefault(f => string.Equals(column, f.Name, StringComparison.OrdinalIgnoreCase)) is not ExistingColumn col) + continue; + + switch (index.Type) + { + case IndexType.Index: + col.IsIndex = true; + break; + case IndexType.Unique: + col.IsIndex = true; + col.IsUnique = true; + break; + case IndexType.PrimaryKey: + col.IsPrimaryKey = true; + col.IsIndex = true; + col.IsUnique = true; + break; + } + } + } + + foreach (var constraint in Descriptor.Constraints) + { + switch (constraint.ConstraintType) + { + case ConstraintType.Default: + if (Columns.FirstOrDefault(f => string.Equals(f.Name, constraint.Columns[0], StringComparison.OrdinalIgnoreCase)) is ExistingColumn column) + column.DefaultValue = constraint.DefaultValue; + break; + } + } + } + + public List Indexes + { + get + { + var result = new List(); + + foreach (var column in Columns) + { + var indexes = ResolveIndexes(column.Name); + + foreach (var index in indexes) + { + if (result.FirstOrDefault(f => string.Equals(f.Name, index.Name, StringComparison.OrdinalIgnoreCase)) is null) + result.Add(index); + } + } + + return result; + } + } + public List ResolveIndexes(string column) + { + var result = new List(); + + foreach (var index in Descriptor.Indexes) + { + if (index.IsReferencedBy(column)) + result.Add(index); + } + + return result; + } + + public bool Equals(ISchema? other) + { + throw new NotImplementedException(); + } + } +} diff --git a/Connected.Data/Schema/Sql/IdentityInsert.cs b/Connected.Data/Schema/Sql/IdentityInsert.cs new file mode 100644 index 0000000..341c2d1 --- /dev/null +++ b/Connected.Data/Schema/Sql/IdentityInsert.cs @@ -0,0 +1,34 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class IdentityInsert : TableTransaction + { + public IdentityInsert(string tableName, bool on) + { + On = on; + TableName = tableName; + } + + private string TableName { get; } + private bool On { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + var switchCommand = On ? "ON" : "OFF"; + + text.AppendLine($"SET IDENTITY_INSERT {Escape(Context.Schema.SchemaName(), TableName)} {switchCommand}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/IndexCreate.cs b/Connected.Data/Schema/Sql/IndexCreate.cs new file mode 100644 index 0000000..9365e22 --- /dev/null +++ b/Connected.Data/Schema/Sql/IndexCreate.cs @@ -0,0 +1,44 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class IndexCreate : TableTransaction + { + public IndexCreate(IndexDescriptor index) + { + Index = index; + } + + private IndexDescriptor Index { get; } + + protected override async Task OnExecute() + { + if (Index.Unique) + await new UniqueConstraintAdd(Index).Execute(Context); + else + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"CREATE NONCLUSTERED INDEX [{Context.GenerateConstraintName(Context.Schema.Schema, Context.Schema.Name, ConstraintNameType.Index)}] ON {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}("); + var comma = string.Empty; + + foreach (var column in Index.Columns) + { + text.AppendLine($"{comma}{Escape(column)} ASC"); + + comma = ","; + } + + text.AppendLine($") ON {Escape(SchemaExtensions.FileGroup)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/IndexDescriptor.cs b/Connected.Data/Schema/Sql/IndexDescriptor.cs new file mode 100644 index 0000000..1e4be7c --- /dev/null +++ b/Connected.Data/Schema/Sql/IndexDescriptor.cs @@ -0,0 +1,18 @@ +namespace Connected.Data.Schema.Sql +{ + internal class IndexDescriptor + { + private List? _columns; + + public bool Unique { get; set; } + + public string? Group { get; set; } + + public List Columns => _columns ??= new List(); + + public override string ToString() + { + return string.IsNullOrWhiteSpace(Group) ? Columns[0] : Group; + } + } +} diff --git a/Connected.Data/Schema/Sql/IndexDrop.cs b/Connected.Data/Schema/Sql/IndexDrop.cs new file mode 100644 index 0000000..c38b1eb --- /dev/null +++ b/Connected.Data/Schema/Sql/IndexDrop.cs @@ -0,0 +1,40 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class IndexDrop : TableTransaction + { + public IndexDrop(ObjectIndex index) + { + Index = index; + } + + private ObjectIndex Index { get; } + + protected override async Task OnExecute() + { + switch (Index.Type) + { + case IndexType.Index: + await Context.Execute(CommandText); + break; + case IndexType.Unique: + case IndexType.PrimaryKey: + await new ConstraintDrop(Index).Execute(Context); + break; + } + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"DROP INDEX {Index.Name} ON {Escape(Context.Schema.SchemaName(), Context.Schema.Name)};"); + + return text.ToString(); + } + } + } +} \ No newline at end of file diff --git a/Connected.Data/Schema/Sql/ObjectColumn.cs b/Connected.Data/Schema/Sql/ObjectColumn.cs new file mode 100644 index 0000000..708eef6 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectColumn.cs @@ -0,0 +1,16 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectColumn + { + public string Name { get; set; } + public string Type { get; set; } + public bool Computed { get; set; } + public int Length { get; set; } + public int Precision { get; set; } + public int Scale { get; set; } + public bool Nullable { get; set; } + public string TrimTrailingBlanks { get; set; } + public string FixedLenInSource { get; set; } + public string Collation { get; set; } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectConstraint.cs b/Connected.Data/Schema/Sql/ObjectConstraint.cs new file mode 100644 index 0000000..5e74971 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectConstraint.cs @@ -0,0 +1,61 @@ +namespace Connected.Data.Schema.Sql +{ + public enum ConstraintType + { + Unknown = 0, + Default = 1, + Unique = 2, + PrimaryKey = 3 + } + internal class ObjectConstraint + { + public string Type { get; set; } + public string Name { get; set; } + public string DeleteAction { get; set; } + public string UpdateAction { get; set; } + public string StatusEnabled { get; set; } + public string StatusForReplication { get; set; } + public string Keys { get; set; } + + public ConstraintType ConstraintType + { + get + { + if (Type.StartsWith("DEFAULT ")) + return ConstraintType.Default; + else if (Type.StartsWith("UNIQUE ")) + return ConstraintType.Unique; + else if (Type.StartsWith("PRIMARY KEY ")) + return ConstraintType.PrimaryKey; + else + return ConstraintType.Unknown; + } + } + + public List Columns + { + get + { + var result = new List(); + + switch (ConstraintType) + { + case ConstraintType.Default: + result.Add(Type.Split(' ')[^1].Trim()); + break; + case ConstraintType.Unique: + case ConstraintType.PrimaryKey: + var tokens = Keys.Split(','); + + foreach (var token in tokens) + result.Add(token); + break; + } + + return result; + } + } + + public string DefaultValue => ConstraintType == ConstraintType.Default && Keys.StartsWith("(") && Keys.EndsWith(")") ? Keys[1..^1] : Keys; + } +} diff --git a/Connected.Data/Schema/Sql/ObjectDescriptor.cs b/Connected.Data/Schema/Sql/ObjectDescriptor.cs new file mode 100644 index 0000000..3f63f54 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectDescriptor.cs @@ -0,0 +1,90 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectDescriptor + { + private ObjectMetaData _metaData = null; + private List _columns = null; + private ObjectIdentity _identity = null; + private ObjectRowGuid _rowGuid = null; + private ObjectFileGroup _fileGroup = null; + private List _indexes = null; + private List _constraints = null; + + public ObjectFileGroup FileGroup + { + get + { + if (_fileGroup == null) + _fileGroup = new ObjectFileGroup(); + + return _fileGroup; + } + } + + public ObjectRowGuid RowGuid + { + get + { + if (_rowGuid == null) + _rowGuid = new ObjectRowGuid(); + + return _rowGuid; + } + } + + public ObjectIdentity Identity + { + get + { + if (_identity == null) + _identity = new ObjectIdentity(); + + return _identity; + } + } + + public ObjectMetaData MetaData + { + get + { + if (_metaData == null) + _metaData = new ObjectMetaData(); + + return _metaData; + } + } + + public List Columns + { + get + { + if (_columns == null) + _columns = new List(); + + return _columns; + } + } + + public List Indexes + { + get + { + if (_indexes == null) + _indexes = new List(); + + return _indexes; + } + } + + public List Constraints + { + get + { + if (_constraints == null) + _constraints = new List(); + + return _constraints; + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectFileGroup.cs b/Connected.Data/Schema/Sql/ObjectFileGroup.cs new file mode 100644 index 0000000..3508066 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectFileGroup.cs @@ -0,0 +1,7 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectFileGroup + { + public string FileGroup { get; set; } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectIdentity.cs b/Connected.Data/Schema/Sql/ObjectIdentity.cs new file mode 100644 index 0000000..cb568ba --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectIdentity.cs @@ -0,0 +1,10 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectIdentity + { + public string Identity { get; set; } + public int Seed { get; set; } + public int Increment { get; set; } + public bool NotForReplication { get; set; } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectIndex.cs b/Connected.Data/Schema/Sql/ObjectIndex.cs new file mode 100644 index 0000000..d6635b8 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectIndex.cs @@ -0,0 +1,53 @@ +namespace Connected.Data.Schema.Sql +{ + public enum IndexType + { + Index = 1, + Unique = 2, + PrimaryKey = 3 + } + internal class ObjectIndex + { + public string Name { get; set; } + public string Description { get; set; } + public string Keys { get; set; } + + public IndexType Type + { + get + { + var tokens = Description.Split(','); + var result = IndexType.Index; + + foreach (var token in tokens) + { + if (token.Trim().Contains("primary key", StringComparison.OrdinalIgnoreCase)) + return IndexType.PrimaryKey; + else if (string.Compare(token.Trim(), "unique", true) == 0) + result = IndexType.Unique; + } + + return result; + } + } + + public bool IsReferencedBy(string column) + { + return Columns.Contains(column, StringComparer.OrdinalIgnoreCase); + } + + public List Columns + { + get + { + var result = new List(); + var tokens = Keys.Split(','); + + foreach (var token in tokens) + result.Add(token.Trim()); + + return result; + } + } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectMetaData.cs b/Connected.Data/Schema/Sql/ObjectMetaData.cs new file mode 100644 index 0000000..1351980 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectMetaData.cs @@ -0,0 +1,10 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectMetaData + { + public string Name { get; set; } + public string Owner { get; set; } + public string Type { get; set; } + public DateTime Created { get; set; } + } +} diff --git a/Connected.Data/Schema/Sql/ObjectRowGuid.cs b/Connected.Data/Schema/Sql/ObjectRowGuid.cs new file mode 100644 index 0000000..471ab55 --- /dev/null +++ b/Connected.Data/Schema/Sql/ObjectRowGuid.cs @@ -0,0 +1,7 @@ +namespace Connected.Data.Schema.Sql +{ + internal class ObjectRowGuid + { + public string RowGuidCol { get; set; } + } +} diff --git a/Connected.Data/Schema/Sql/PrimaryKeyAdd.cs b/Connected.Data/Schema/Sql/PrimaryKeyAdd.cs new file mode 100644 index 0000000..1622644 --- /dev/null +++ b/Connected.Data/Schema/Sql/PrimaryKeyAdd.cs @@ -0,0 +1,31 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class PrimaryKeyAdd : ColumnTransaction + { + public PrimaryKeyAdd(ISchemaColumn column) : base(column) + { + + } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + text.AppendLine($"ADD CONSTRAINT {Context.GenerateConstraintName(Context.Schema.SchemaName(), Context.Schema.Name, ConstraintNameType.PrimaryKey)}"); + text.AppendLine($"PRIMARY KEY CLUSTERED ({Escape(Column.Name)}) ON {Escape(SchemaExtensions.FileGroup)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/PrimaryKeyRemove.cs b/Connected.Data/Schema/Sql/PrimaryKeyRemove.cs new file mode 100644 index 0000000..ba4d9f6 --- /dev/null +++ b/Connected.Data/Schema/Sql/PrimaryKeyRemove.cs @@ -0,0 +1,22 @@ +namespace Connected.Data.Schema.Sql +{ + internal class PrimaryKeyRemove : ColumnTransaction + { + public PrimaryKeyRemove(ExistingSchema existing, ISchemaColumn column) : base(column) + { + Existing = existing; + } + + private ExistingSchema Existing { get; } + + protected override async Task OnExecute() + { + if (Existing.Indexes.FirstOrDefault(f => f.Type == IndexType.PrimaryKey) is ObjectIndex constraint) + { + await new ConstraintDrop(constraint).Execute(Context); + + Existing.Indexes.Remove(constraint); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/SchemaExecutionContext.cs b/Connected.Data/Schema/Sql/SchemaExecutionContext.cs new file mode 100644 index 0000000..01d7097 --- /dev/null +++ b/Connected.Data/Schema/Sql/SchemaExecutionContext.cs @@ -0,0 +1,126 @@ +using Connected.Data.Sql; +using Connected.Entities.Storage; +using System.Data; + +namespace Connected.Data.Schema.Sql; + +internal class SchemaExecutionContext +{ + private ExistingSchema _existingSchema; + public SchemaExecutionContext(IStorageProvider storage, ISchema schema, string connectionString) + { + Storage = storage; + Schema = schema; + ConnectionString = connectionString; + Constraints = new(); + } + + public ExistingSchema ExistingSchema + { + get => _existingSchema; set + { + _existingSchema = value; + + if (_existingSchema is null) + return; + + foreach (var index in _existingSchema.Descriptor.Constraints) + { + switch (index.ConstraintType) + { + case ConstraintType.Default: + AddConstraint(ConstraintNameType.Default, index.Name); + break; + case ConstraintType.PrimaryKey: + AddConstraint(ConstraintNameType.PrimaryKey, index.Name); + break; + case ConstraintType.Unique: + AddConstraint(ConstraintNameType.Index, index.Name); + break; + } + } + + } + } + + public IStorageProvider Storage { get; } + public ISchema Schema { get; } + private string ConnectionString { get; } + public Dictionary> Constraints { get; } + + public async Task Execute(string commandText) + { + await Storage.Open().Select(new SchemaStorageArgs(new StorageOperation { CommandText = commandText }, typeof(SqlDataConnection), ConnectionString)); + } + + public async Task Select(string commandText) + { + return await Storage.Open().Select(new SchemaStorageArgs(new StorageOperation { CommandText = commandText }, typeof(SqlDataConnection), ConnectionString)); + } + + public async Task OpenReader(IStorageOperation operation) + { + var readers = await Storage.Open().OpenReaders(new SchemaStorageArgs(operation, typeof(SqlDataConnection), ConnectionString)); + + return readers[0]; + } + + private void AddConstraint(ConstraintNameType type, string name) + { + if (!Constraints.TryGetValue(type, out List? existing)) + { + existing = new List(); + + Constraints.Add(type, existing); + } + + if (ConstraintNameExists(name)) + existing.Add(name); + } + + public string GenerateConstraintName(string schema, string tableName, ConstraintNameType type) + { + var index = 0; + + while (true) + { + var value = $"{ConstraintPrefix(type)}_{schema.ToLowerInvariant()}_{tableName}"; + + if (index > 0) + value = $"{value}_{index}"; + + if (!ConstraintNameExists(value)) + { + AddConstraint(type, value); + return value; + } + + index++; + } + } + + private bool ConstraintNameExists(string value) + { + foreach (var key in Constraints) + { + foreach (var item in key.Value) + { + if (item.Contains(value, StringComparison.OrdinalIgnoreCase)) + return true; + } + } + + return false; + } + + private static string ConstraintPrefix(ConstraintNameType type) + { + return type switch + { + ConstraintNameType.Default => "DF", + ConstraintNameType.PrimaryKey => "PK", + ConstraintNameType.Index => "IX", + _ => "IX" + }; + } +} diff --git a/Connected.Data/Schema/Sql/SchemaExists.cs b/Connected.Data/Schema/Sql/SchemaExists.cs new file mode 100644 index 0000000..89b600e --- /dev/null +++ b/Connected.Data/Schema/Sql/SchemaExists.cs @@ -0,0 +1,25 @@ +using Connected.Entities.Storage; + +namespace Connected.Data.Schema.Sql; + +internal class SchemaExists : SynchronizationQuery +{ + public SchemaExists(string name) + { + Name = name; + } + private string Name { get; } + + protected override async Task OnExecute() + { + if (string.IsNullOrWhiteSpace(Name)) + return true; + + var rdr = await Context.OpenReader(new StorageOperation { CommandText = $"SELECT * FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '{Name}'" }); + var result = rdr.Read(); + + rdr.Close(); + + return result; + } +} diff --git a/Connected.Data/Schema/Sql/SchemaExtensions.cs b/Connected.Data/Schema/Sql/SchemaExtensions.cs new file mode 100644 index 0000000..8e14744 --- /dev/null +++ b/Connected.Data/Schema/Sql/SchemaExtensions.cs @@ -0,0 +1,233 @@ +using System.Data; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema.Sql; + +internal static class SchemaExtensions +{ + public const string FileGroup = "PRIMARY"; + public static ITable? Find(this List tables, string schema, string name) + { + return tables.FirstOrDefault(f => string.Equals(schema, f.Schema, StringComparison.OrdinalIgnoreCase) && string.Equals(name, f.Name, StringComparison.OrdinalIgnoreCase)); + } + + public static ITableColumn? FindColumn(this ITable table, string name) + { + return table.Columns.FirstOrDefault(f => string.Equals(name, f.Name, StringComparison.OrdinalIgnoreCase)); + } + + public static ITable? FindPrimaryKeyTable(this IDatabase database, string name) + { + foreach (var i in database.Tables) + { + foreach (var j in i.Columns) + { + foreach (var k in j.Constraints) + { + if (string.Equals(k.Type, "PRIMARY KEY", StringComparison.OrdinalIgnoreCase) && string.Equals(k.Name, name, StringComparison.OrdinalIgnoreCase)) + return i; + } + } + } + + return null; + } + + public static ITableColumn? FindPrimaryKeyColumn(this IDatabase database, string name) + { + foreach (var i in database.Tables) + { + foreach (var j in i.Columns) + { + foreach (var k in j.Constraints) + { + if (string.Equals(k.Type, "PRIMARY KEY", StringComparison.OrdinalIgnoreCase) && string.Equals(k.Name, name, StringComparison.OrdinalIgnoreCase)) + return j; + } + } + } + + return null; + } + + public static ITableColumn? ResolvePrimaryKeyColumn(this ITable table) + { + foreach (var i in table.Columns) + { + foreach (var j in i.Constraints) + { + if (string.Equals(j.Type, "PRIMARY KEY", StringComparison.OrdinalIgnoreCase)) + return i; + } + } + + return null; + } + + public static ITableConstraint? ResolvePrimaryKey(this ITable table) + { + foreach (var i in table.Columns) + { + foreach (var j in i.Constraints) + { + if (string.Equals(j.Type, "PRIMARY KEY", StringComparison.OrdinalIgnoreCase)) + return j; + } + } + + return null; + } + + public static List ResolveDefaults(this ITable table) + { + var r = new List(); + + foreach (var i in table.Columns) + { + if (!string.IsNullOrWhiteSpace(i.DefaultValue)) + r.Add(i); + } + + return r; + } + + public static List ResolveUniqueConstraints(this ITable table) + { + var r = new List(); + + foreach (var i in table.Columns) + { + foreach (var j in i.Constraints) + { + if (string.Equals(j.Type, "UNIQUE", StringComparison.OrdinalIgnoreCase) && r.FirstOrDefault(f => string.Equals(f.Name, j.Name, StringComparison.OrdinalIgnoreCase)) is null) + r.Add(j); + } + } + return r; + } + + public static ITableColumn? FindUniqueConstraintColumn(this IDatabase database, string name) + { + foreach (var table in database.Tables) + { + foreach (var column in table.Columns) + { + foreach (var constraint in column.Constraints) + { + if (string.Equals(constraint.Type, "UNIQUE", StringComparison.OrdinalIgnoreCase) && string.Equals(constraint.Name, name, StringComparison.OrdinalIgnoreCase)) + return column; + } + } + } + + return null; + } + + public static T GetValue(this IDataReader r, string fieldName, T defaultValue) + { + var idx = r.GetOrdinal(fieldName); + + if (idx == -1) + return defaultValue; + + if (r.IsDBNull(idx)) + return defaultValue; + + return (T)Convert.ChangeType(r.GetValue(idx), typeof(T)); + } + + public static string SchemaName(this ISchema schema) + { + return string.IsNullOrWhiteSpace(schema.Schema) ? SchemaAttribute.DefaultSchema : schema.Schema; + } + + public static string ParseDefaultValue(string value) + { + if (string.IsNullOrEmpty(value)) + return value; + + if (value.StartsWith("N'")) + return value; + + var defValue = $"N'{value}'"; + + if (value.Length > 1) + { + var last = value.Trim()[^1]; + var prev = value.Trim()[0..^1].Trim()[^1]; + + if (last == ')' && prev == '(') + defValue = value; + } + + return defValue; + } + + public static DbType ToDbType(string value) + { + if (string.Equals(value, "bigint", StringComparison.OrdinalIgnoreCase)) + return DbType.Int64; + else if (string.Equals(value, "binary", StringComparison.OrdinalIgnoreCase)) + return DbType.Binary; + else if (string.Equals(value, "bit", StringComparison.OrdinalIgnoreCase)) + return DbType.Boolean; + else if (string.Equals(value, "char", StringComparison.OrdinalIgnoreCase)) + return DbType.AnsiStringFixedLength; + else if (string.Equals(value, "date", StringComparison.OrdinalIgnoreCase)) + return DbType.Date; + else if (string.Equals(value, "datetime", StringComparison.OrdinalIgnoreCase)) + return DbType.DateTime; + else if (string.Equals(value, "datetime2", StringComparison.OrdinalIgnoreCase)) + return DbType.DateTime2; + else if (string.Equals(value, "datetimeoffset", StringComparison.OrdinalIgnoreCase)) + return DbType.DateTimeOffset; + else if (string.Equals(value, "decimal", StringComparison.OrdinalIgnoreCase)) + return DbType.Decimal; + else if (string.Equals(value, "float", StringComparison.OrdinalIgnoreCase)) + return DbType.Double; + else if (string.Equals(value, "geography", StringComparison.OrdinalIgnoreCase)) + return DbType.Object; + else if (string.Equals(value, "hierarchyid", StringComparison.OrdinalIgnoreCase)) + return DbType.Object; + else if (string.Equals(value, "image", StringComparison.OrdinalIgnoreCase)) + return DbType.Binary; + else if (string.Equals(value, "int", StringComparison.OrdinalIgnoreCase)) + return DbType.Int32; + else if (string.Equals(value, "money", StringComparison.OrdinalIgnoreCase)) + return DbType.Currency; + else if (string.Equals(value, "nchar", StringComparison.OrdinalIgnoreCase)) + return DbType.StringFixedLength; + else if (string.Equals(value, "ntext", StringComparison.OrdinalIgnoreCase)) + return DbType.String; + else if (string.Equals(value, "numeric", StringComparison.OrdinalIgnoreCase)) + return DbType.VarNumeric; + else if (string.Equals(value, "nvarchar", StringComparison.OrdinalIgnoreCase)) + return DbType.String; + else if (string.Equals(value, "real", StringComparison.OrdinalIgnoreCase)) + return DbType.Single; + else if (string.Equals(value, "smalldatetime", StringComparison.OrdinalIgnoreCase)) + return DbType.DateTime; + else if (string.Equals(value, "smallmoney", StringComparison.OrdinalIgnoreCase)) + return DbType.Currency; + else if (string.Equals(value, "sql_variant", StringComparison.OrdinalIgnoreCase)) + return DbType.Object; + else if (string.Equals(value, "text", StringComparison.OrdinalIgnoreCase)) + return DbType.String; + else if (string.Equals(value, "time", StringComparison.OrdinalIgnoreCase)) + return DbType.Time; + else if (string.Equals(value, "timestamp", StringComparison.OrdinalIgnoreCase)) + return DbType.Binary; + else if (string.Equals(value, "tinyint", StringComparison.OrdinalIgnoreCase)) + return DbType.Byte; + else if (string.Equals(value, "uniqueidentifier", StringComparison.OrdinalIgnoreCase)) + return DbType.Guid; + else if (string.Equals(value, "varbinary", StringComparison.OrdinalIgnoreCase)) + return DbType.Binary; + else if (string.Equals(value, "varchar", StringComparison.OrdinalIgnoreCase)) + return DbType.AnsiString; + else if (string.Equals(value, "xml", StringComparison.OrdinalIgnoreCase)) + return DbType.Xml; + else + return DbType.String; + } +} \ No newline at end of file diff --git a/Connected.Data/Schema/Sql/SchemaStorageArgs.cs b/Connected.Data/Schema/Sql/SchemaStorageArgs.cs new file mode 100644 index 0000000..4a00027 --- /dev/null +++ b/Connected.Data/Schema/Sql/SchemaStorageArgs.cs @@ -0,0 +1,15 @@ +using Connected.Entities.Storage; + +namespace Connected.Data.Schema.Sql; + +internal sealed class SchemaStorageArgs : StorageContextArgs, ISchemaSynchronizationContext +{ + public SchemaStorageArgs(IStorageOperation operation, Type connectionType, string connectionString) : base(operation) + { + ConnectionType = connectionType; + ConnectionString = connectionString; + } + + public Type ConnectionType { get; } + public string ConnectionString { get; } +} diff --git a/Connected.Data/Schema/Sql/SchemaSynchronize.cs b/Connected.Data/Schema/Sql/SchemaSynchronize.cs new file mode 100644 index 0000000..54cbbb9 --- /dev/null +++ b/Connected.Data/Schema/Sql/SchemaSynchronize.cs @@ -0,0 +1,19 @@ +namespace Connected.Data.Schema.Sql +{ + internal class SchemaSynchronize : SynchronizationTransaction + { + protected override async Task OnExecute() + { + if (string.IsNullOrWhiteSpace(Context.Schema.Schema)) + return; + + if (!await new SchemaExists(Context.Schema.Schema).Execute(Context)) + await CreateSchema(); + } + + private async Task CreateSchema() + { + await Context.Execute($"CREATE SCHEMA {Context.Schema.Schema};"); + } + } +} diff --git a/Connected.Data/Schema/Sql/SpHelp.cs b/Connected.Data/Schema/Sql/SpHelp.cs new file mode 100644 index 0000000..3bfe615 --- /dev/null +++ b/Connected.Data/Schema/Sql/SpHelp.cs @@ -0,0 +1,145 @@ +using Connected.Data.Storage; +using Connected.Entities.Storage; +using Connected.Interop; +using System.Data; + +namespace Connected.Data.Schema.Sql; + +internal class SpHelp : SynchronizationQuery +{ + private readonly ObjectDescriptor _descriptor; + + public SpHelp() + { + _descriptor = new(); + } + + private ObjectDescriptor Result => _descriptor; + + protected override async Task OnExecute() + { + var operation = new StorageOperation { CommandText = "sp_help", CommandType = CommandType.Text }; + + operation.AddParameter(new StorageParameter + { + Name = "@objname", + Type = DbType.String, + Value = Escape(Context.Schema.SchemaName(), Context.Schema.Name) + } + ); + + var rdr = await Context.OpenReader(operation); + + try + { + ReadMetadata(rdr); + ReadColumns(rdr); + ReadIdentity(rdr); + ReadRowGuid(rdr); + ReadFileGroup(rdr); + ReadIndexes(rdr); + ReadConstraints(rdr); + } + finally + { + rdr.Close(); + } + + return Result; + } + + private void ReadMetadata(IDataReader rdr) + { + if (rdr.Read()) + { + Result.MetaData.Name = rdr.GetValue("Name", string.Empty); + Result.MetaData.Created = rdr.GetValue("Created_datetime", DateTime.MinValue); + Result.MetaData.Owner = rdr.GetValue("Owner", string.Empty); + Result.MetaData.Type = rdr.GetValue("Type", string.Empty); + } + } + + private void ReadColumns(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Columns.Add(new ObjectColumn + { + Collation = rdr.GetValue("Collation", string.Empty), + Computed = !string.Equals(rdr.GetValue("Computed", string.Empty), "no", StringComparison.OrdinalIgnoreCase), + FixedLenInSource = rdr.GetValue("FixedLenNullInSource", string.Empty), + Length = rdr.GetValue("Length", 0), + Name = rdr.GetValue("Column_name", string.Empty), + Nullable = !string.Equals(rdr.GetValue("Nullable", string.Empty), "no", StringComparison.OrdinalIgnoreCase), + Precision = TypeConversion.Convert(rdr.GetValue("Prec", string.Empty).Trim()), + Scale = TypeConversion.Convert(rdr.GetValue("Scale", string.Empty).Trim()), + TrimTrailingBlanks = rdr.GetValue("TrimTrailingBlanks", string.Empty), + Type = rdr.GetValue("Type", string.Empty) + }); + } + } + + private void ReadIdentity(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + { + Result.Identity.Identity = rdr.GetValue("Identity", string.Empty); + Result.Identity.Increment = rdr.GetValue("Increment", 0); + Result.Identity.NotForReplication = rdr.GetValue("Not For Replication", 0) != 0; + } + } + + private void ReadRowGuid(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + Result.RowGuid.RowGuidCol = rdr.GetValue("RowGuidCol", string.Empty); + } + + private void ReadFileGroup(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + Result.FileGroup.FileGroup = rdr.GetValue("Data_located_on_filegroup", string.Empty); + } + + private void ReadIndexes(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Indexes.Add(new ObjectIndex + { + Description = rdr.GetValue("index_description", string.Empty), + Keys = rdr.GetValue("index_keys", string.Empty), + Name = rdr.GetValue("index_name", string.Empty) + }); + } + } + + private void ReadConstraints(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Constraints.Add(new ObjectConstraint + { + DeleteAction = rdr.GetValue("delete_action", string.Empty), + Keys = rdr.GetValue("constraint_keys", string.Empty), + Name = rdr.GetValue("constraint_name", string.Empty), + StatusEnabled = rdr.GetValue("status_enabled", string.Empty), + StatusForReplication = rdr.GetValue("status_for_replication", string.Empty), + Type = rdr.GetValue("constraint_type", string.Empty), + UpdateAction = rdr.GetValue("update_action", string.Empty) + }); + } + } +} diff --git a/Connected.Data/Schema/Sql/SqlSchemaMiddleware.cs b/Connected.Data/Schema/Sql/SqlSchemaMiddleware.cs new file mode 100644 index 0000000..6a3f805 --- /dev/null +++ b/Connected.Data/Schema/Sql/SqlSchemaMiddleware.cs @@ -0,0 +1,86 @@ +using Connected.Annotations; +using Connected.Configuration; +using Connected.Data.Sharding; +using Connected.Data.Sql; +using Connected.Entities; +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using Connected.Middleware; + +namespace Connected.Data.Schema.Sql; + +internal enum ConstraintNameType +{ + Index = 1, + PrimaryKey = 2, + Default = 3 +} + +[Priority(0)] +internal sealed class SqlSchemaMiddleware : MiddlewareComponent, ISchemaMiddleware +{ + public SqlSchemaMiddleware(IMiddlewareService middleware, IStorageProvider storage, IConfigurationService configuration) + { + Middleware = middleware; + Storage = storage; + Configuration = configuration; + } + + private IMiddlewareService Middleware { get; } + private IStorageProvider Storage { get; } + public IConfigurationService Configuration { get; } + + public Type ConnectionType => typeof(SqlDataConnection); + + public string DefaultConnectionString => Configuration.Storage.Databases.DefaultConnectionString; + + public async Task IsEntitySupported(Type entityType) + { + await Task.CompletedTask; + /* + * By default, all entities are supported by this middleware. + */ + return entityType.IsAssignableTo(typeof(IEntity)); + } + + public async Task Synchronize(Type entity, ISchema schema) + { + await Synchronize(schema, DefaultConnectionString); + /* + * First query all sharding middleware because we must perform synchronization + * on all nodes. + */ + if (await ResolveShardingMiddleware(entity) is IShardingMiddleware sharding) + { + foreach (var node in await sharding.ProvideNodes(entity)) + await Synchronize(schema, node.ConnectionString); + } + } + + private async Task Synchronize(ISchema schema, string connectionString) + { + var args = new SchemaExecutionContext(Storage, schema, connectionString); + /* + * Sinchronize schema object first. + */ + await new SchemaSynchronize().Execute(args); + /* + * Only tables are supported + */ + if (string.IsNullOrWhiteSpace(schema.Type) || string.Equals(schema.Type, SchemaAttribute.SchemaTypeTable, StringComparison.OrdinalIgnoreCase)) + await new TableSynchronize().Execute(args); + } + + private async Task ResolveShardingMiddleware(Type entityType) + { + var all = await Middleware.Query(); + + foreach (var middleware in all) + { + if (middleware.SupportsEntity(entityType)) + return middleware; + } + + return null; + } +} diff --git a/Connected.Data/Schema/Sql/SynchronizationCommand.cs b/Connected.Data/Schema/Sql/SynchronizationCommand.cs new file mode 100644 index 0000000..badb2af --- /dev/null +++ b/Connected.Data/Schema/Sql/SynchronizationCommand.cs @@ -0,0 +1,20 @@ +namespace Connected.Data.Schema.Sql +{ + internal abstract class SynchronizationCommand + { + public static string Escape(string value) + { + return $"[{Unescape(value)}]"; + } + + public static string Unescape(string value) + { + return value.TrimStart('[').TrimEnd(']'); + } + + public static string Escape(string schema, string name) + { + return $"{Escape(schema)}.{Escape(name)}"; + } + } +} diff --git a/Connected.Data/Schema/Sql/SynchronizationQuery.cs b/Connected.Data/Schema/Sql/SynchronizationQuery.cs new file mode 100644 index 0000000..28b054c --- /dev/null +++ b/Connected.Data/Schema/Sql/SynchronizationQuery.cs @@ -0,0 +1,21 @@ +namespace Connected.Data.Schema.Sql +{ + internal abstract class SynchronizationQuery : SynchronizationCommand + { + protected SchemaExecutionContext Context { get; private set; } + + public async Task Execute(SchemaExecutionContext context) + { + Context = context; + + return await OnExecute(); + } + + protected virtual async Task OnExecute() + { + await Task.CompletedTask; + + return default; + } + } +} diff --git a/Connected.Data/Schema/Sql/SynchronizationTransaction.cs b/Connected.Data/Schema/Sql/SynchronizationTransaction.cs new file mode 100644 index 0000000..cf8942d --- /dev/null +++ b/Connected.Data/Schema/Sql/SynchronizationTransaction.cs @@ -0,0 +1,19 @@ +namespace Connected.Data.Schema.Sql +{ + internal abstract class SynchronizationTransaction : SynchronizationCommand + { + protected SchemaExecutionContext Context { get; private set; } + + public async Task Execute(SchemaExecutionContext context) + { + Context = context; + + await OnExecute(); + } + + protected virtual async Task OnExecute() + { + await Task.CompletedTask; + } + } +} diff --git a/Connected.Data/Schema/Sql/TableAlter.cs b/Connected.Data/Schema/Sql/TableAlter.cs new file mode 100644 index 0000000..965ed55 --- /dev/null +++ b/Connected.Data/Schema/Sql/TableAlter.cs @@ -0,0 +1,132 @@ +namespace Connected.Data.Schema.Sql +{ + internal class TableAlter : TableSynchronize + { + public TableAlter(ExistingSchema schema) + { + Existing = schema; + } + + private ExistingSchema Existing { get; } + + protected override async Task OnExecute() + { + var dropped = new List(); + + foreach (var index in Existing.Indexes) + { + if (!ColumnsMatched(index)) + { + await new IndexDrop(index).Execute(Context); + dropped.Add(index); + } + } + + foreach (var drop in dropped) + Existing.Indexes.Remove(drop); + + foreach (var existingColumn in Existing.Columns) + { + if (Context.Schema.Columns.FirstOrDefault(f => string.Equals(f.Name, existingColumn.Name, StringComparison.OrdinalIgnoreCase)) is not ISchemaColumn column) + await new ColumnDrop(existingColumn, Existing).Execute(Context); + else + await new ColumnAlter(column, Existing, existingColumn).Execute(Context); + } + + var indexes = ParseIndexes(Context.Schema); + + foreach (var index in indexes) + { + if (!IndexExists(index)) + await new IndexCreate(index).Execute(Context); + } + } + + private bool IndexExists(IndexDescriptor index) + { + var existingIndexes = Existing.Indexes.Where(f => f.Type != IndexType.PrimaryKey); + + foreach (var existingIndex in existingIndexes) + { + if (index.Unique && existingIndex.Type != IndexType.Unique) + continue; + + if (!index.Unique && existingIndex.Type == IndexType.Unique) + continue; + + var cols = index.Columns.OrderBy(f => f); + var existingCols = existingIndex.Columns.OrderBy(f => f); + + if (cols.Count() != existingCols.Count()) + continue; + + for (var i = 0; i < cols.Count(); i++) + { + if (!string.Equals(cols.ElementAt(i), existingCols.ElementAt(i), StringComparison.OrdinalIgnoreCase)) + break; + } + + return true; + } + + return false; + } + + private bool ColumnsMatched(ObjectIndex index) + { + if (index.Columns.Count == 1) + return ColumnMatched(index); + + var indexGroup = string.Empty; + var columns = new List(); + + foreach (var column in Context.Schema.Columns) + { + if (index.Columns.Contains(column.Name, StringComparer.OrdinalIgnoreCase)) + { + if (string.IsNullOrWhiteSpace(column.Index)) + return false; + + if (!string.Equals(indexGroup, column.Index, StringComparison.OrdinalIgnoreCase)) + return false; + + if (string.IsNullOrWhiteSpace(indexGroup)) + indexGroup = column.Index; + + columns.Add(column); + } + } + + foreach (var column in Context.Schema.Columns) + { + if (string.Equals(column.Index, indexGroup, StringComparison.OrdinalIgnoreCase) && !columns.Contains(column) && column.IsIndex) + columns.Add(column); + } + + if (index.Columns.Count != columns.Count) + return false; + + foreach (var column in columns.OrderBy(f => f.Name)) + { + if (!index.Columns.Contains(column.Name, StringComparer.OrdinalIgnoreCase)) + return false; + } + + return true; + } + + private bool ColumnMatched(ObjectIndex index) + { + if (Context.Schema.Columns.FirstOrDefault(f => string.Equals(f.Name, index.Columns[0], StringComparison.OrdinalIgnoreCase)) is not ISchemaColumn column) + return false; + + if (!column.IsIndex) + return false; + + if (index.Type == IndexType.Unique && !column.IsUnique) + return false; + + return true; + } + } +} diff --git a/Connected.Data/Schema/Sql/TableCreate.cs b/Connected.Data/Schema/Sql/TableCreate.cs new file mode 100644 index 0000000..84ef7b9 --- /dev/null +++ b/Connected.Data/Schema/Sql/TableCreate.cs @@ -0,0 +1,83 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class TableCreate : TableTransaction + { + public TableCreate(bool temporary) + { + Temporary = temporary; + + if (Temporary) + TemporaryName = $"T{Guid.NewGuid().ToString().Replace("-", string.Empty)}"; + } + + private bool Temporary { get; } + + public string TemporaryName { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + + if (!Temporary) + { + await ExecutePrimaryKey(); + await ExecuteDefaults(); + await ExecuteIndexes(); + } + } + + private async Task ExecutePrimaryKey() + { + var primaryKey = Context.Schema.Columns.FirstOrDefault(f => f.IsPrimaryKey); + + if (primaryKey is not null) + await new PrimaryKeyAdd(primaryKey).Execute(Context); + } + + private async Task ExecuteDefaults() + { + var name = Temporary ? TemporaryName : Context.Schema.Name; + + foreach (var column in Context.Schema.Columns) + { + if (!string.IsNullOrWhiteSpace(column.DefaultValue)) + await new DefaultAdd(column, name).Execute(Context); + } + } + + private async Task ExecuteIndexes() + { + var indexes = ParseIndexes(Context.Schema); + + foreach (var index in indexes) + await new IndexCreate(index).Execute(Context); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + var name = Temporary ? TemporaryName : Context.Schema.Name; + + text.AppendLine($"CREATE TABLE {Escape(Context.Schema.SchemaName(), name)}"); + text.AppendLine("("); + var comma = string.Empty; + + for (var i = 0; i < Context.Schema.Columns.Count; i++) + { + text.AppendLine($"{comma} {CreateColumnCommandText(Context.Schema.Columns[i])}"); + + comma = ","; + } + + text.AppendLine(");"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/TableDrop.cs b/Connected.Data/Schema/Sql/TableDrop.cs new file mode 100644 index 0000000..7746e10 --- /dev/null +++ b/Connected.Data/Schema/Sql/TableDrop.cs @@ -0,0 +1,24 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class TableDrop : TableTransaction + { + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"DROP TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/TableExists.cs b/Connected.Data/Schema/Sql/TableExists.cs new file mode 100644 index 0000000..e3dce5d --- /dev/null +++ b/Connected.Data/Schema/Sql/TableExists.cs @@ -0,0 +1,24 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class TableExists : SynchronizationQuery + { + protected override async Task OnExecute() + { + return (await Context.Select(CommandText)).Result; + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"IF (EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{Unescape(Context.Schema.SchemaName())}' AND TABLE_NAME = '{Unescape(Context.Schema.Name)}')) SELECT 1 as result ELSE SELECT 0 as result"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/TableRecreate.cs b/Connected.Data/Schema/Sql/TableRecreate.cs new file mode 100644 index 0000000..89e89b4 --- /dev/null +++ b/Connected.Data/Schema/Sql/TableRecreate.cs @@ -0,0 +1,74 @@ +namespace Connected.Data.Schema.Sql +{ + internal class TableRecreate : TableTransaction + { + public TableRecreate(ExistingSchema existing) + { + Existing = existing; + } + + private ExistingSchema Existing { get; } + + protected override async Task OnExecute() + { + var add = new TableCreate(true); + + await add.Execute(Context); + + await ExecuteDefaults(add.TemporaryName); + + if (HasIdentity) + await new IdentityInsert(add.TemporaryName, true).Execute(Context); + + await new DataCopy(Existing, add.TemporaryName).Execute(Context); + + if (HasIdentity) + await new IdentityInsert(add.TemporaryName, false).Execute(Context); + + await new TableDrop().Execute(Context); + await new TableRename(add.TemporaryName).Execute(Context); + + await ExecutePrimaryKey(); + await ExecuteIndexes(); + } + + private bool HasIdentity + { + get + { + foreach (var column in Context.Schema.Columns) + { + if (column.IsPrimaryKey && column.IsIdentity) + return true; + } + + return false; + } + } + + private async Task ExecutePrimaryKey() + { + var pk = Context.Schema.Columns.FirstOrDefault(f => f.IsPrimaryKey); + + if (pk != null) + await new PrimaryKeyAdd(pk).Execute(Context); + } + + private async Task ExecuteDefaults(string tableName) + { + foreach (var column in Context.Schema.Columns) + { + if (!string.IsNullOrWhiteSpace(column.DefaultValue)) + await new DefaultAdd(column, tableName).Execute(Context); + } + } + + private async Task ExecuteIndexes() + { + var indexes = ParseIndexes(Context.Schema); + + foreach (var index in indexes) + await new IndexCreate(index).Execute(Context); + } + } +} diff --git a/Connected.Data/Schema/Sql/TableRename.cs b/Connected.Data/Schema/Sql/TableRename.cs new file mode 100644 index 0000000..5fecc52 --- /dev/null +++ b/Connected.Data/Schema/Sql/TableRename.cs @@ -0,0 +1,31 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class TableRename : TableTransaction + { + public TableRename(string temporaryName) + { + TemporaryName = temporaryName; + } + + private string TemporaryName { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"EXECUTE sp_rename N'{Unescape(Context.Schema.SchemaName())}.{Unescape(TemporaryName)}', N'{Unescape(Context.Schema.Name)}', 'OBJECT'"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Schema/Sql/TableSynchronize.cs b/Connected.Data/Schema/Sql/TableSynchronize.cs new file mode 100644 index 0000000..56ff99a --- /dev/null +++ b/Connected.Data/Schema/Sql/TableSynchronize.cs @@ -0,0 +1,88 @@ +namespace Connected.Data.Schema.Sql +{ + internal class TableSynchronize : TableTransaction + { + private ExistingSchema? _existingSchema; + + private bool TableExists { get; set; } + + protected override async Task OnExecute() + { + TableExists = await new TableExists().Execute(Context); + + if (!TableExists) + { + await new TableCreate(false).Execute(Context); + return; + } + + _existingSchema = new(); + + await _existingSchema.Load(Context); + + Context.ExistingSchema = ExistingSchema; + + if (ShouldRecreate) + await new TableRecreate(ExistingSchema).Execute(Context); + else if (ShouldAlter) + await new TableAlter(ExistingSchema).Execute(Context); + } + + private bool ShouldAlter => !Context.Schema.Equals(ExistingSchema); + private bool ShouldRecreate => HasIdentityChanged || HasColumnMetadataChanged; + + private ExistingSchema? ExistingSchema => _existingSchema; + + private bool HasIdentityChanged + { + get + { + foreach (var column in Context.Schema.Columns) + { + if (ExistingSchema.Columns.FirstOrDefault(f => string.Equals(f.Name, column.Name, StringComparison.OrdinalIgnoreCase)) is not ISchemaColumn existing) + return true; + + if (existing.IsIdentity != column.IsIdentity) + return true; + } + + foreach (var existing in ExistingSchema.Columns) + { + var column = Context.Schema.Columns.FirstOrDefault(f => string.Equals(f.Name, existing.Name, StringComparison.OrdinalIgnoreCase)); + + if (column is null && existing.IsIdentity) + return true; + else if (column is not null && column.IsIdentity != existing.IsIdentity) + return true; + } + + return false; + } + } + + private bool HasColumnMetadataChanged + { + get + { + foreach (var existing in ExistingSchema.Columns) + { + if (Context.Schema.Columns.FirstOrDefault(f => string.Equals(f.Name, existing.Name, StringComparison.OrdinalIgnoreCase)) is not ISchemaColumn column) + continue; + + if (column.DataType != existing.DataType + || column.MaxLength != existing.MaxLength + || column.IsNullable != existing.IsNullable + || column.IsVersion != existing.IsVersion + || column.Precision != existing.Precision + || column.Scale != existing.Scale + || column.DateKind != existing.DateKind + || column.BinaryKind != existing.BinaryKind + || column.DatePrecision != existing.DatePrecision) + return true; + } + + return false; + } + } + } +} diff --git a/Connected.Data/Schema/Sql/TableTransaction.cs b/Connected.Data/Schema/Sql/TableTransaction.cs new file mode 100644 index 0000000..12b4fbb --- /dev/null +++ b/Connected.Data/Schema/Sql/TableTransaction.cs @@ -0,0 +1,125 @@ +using System.Data; +using System.Text; +using Connected.Entities.Annotations; + +namespace Connected.Data.Schema.Sql; + +internal abstract class TableTransaction : SynchronizationTransaction +{ + protected static string CreateColumnCommandText(ISchemaColumn column) + { + var builder = new StringBuilder(); + + builder.AppendFormat($"{Escape(column.Name)} {CreateDataTypeMetaData(column)} "); + + if (column.IsIdentity) + builder.Append("IDENTITY(1,1) "); + + if (column.IsNullable) + builder.Append("NULL "); + else + builder.Append("NOT NULL "); + + return builder.ToString(); + } + + protected static string ResolveColumnLength(ISchemaColumn column) + { + if (column.MaxLength == -1) + return "MAX"; + + if (column.MaxLength > 0) + return column.MaxLength.ToString(); + + return column.DataType switch + { + DbType.AnsiString or DbType.String or DbType.AnsiStringFixedLength or DbType.StringFixedLength => 50.ToString(), + DbType.Binary => 128.ToString(), + DbType.Time or DbType.DateTime2 or DbType.DateTimeOffset => column.DatePrecision.ToString(), + DbType.VarNumeric => 8.ToString(), + DbType.Xml => "MAX", + DbType.Decimal => $"{column.Precision}, {column.Scale}", + _ => 50.ToString(), + }; + } + + protected static string CreateDataTypeMetaData(ISchemaColumn column) + { + return column.DataType switch + { + DbType.AnsiString => $"[varchar]({ResolveColumnLength(column)})", + DbType.Binary => column.IsVersion ? "[timestamp]" : column.BinaryKind == BinaryKind.Binary ? $"[binary]({ResolveColumnLength(column)})" : $"[varbinary]({ResolveColumnLength(column)})", + DbType.Byte => "[tinyint]", + DbType.Boolean => "[bit]", + DbType.Currency => "[money]", + DbType.Date => "[date]", + DbType.DateTime => column.DateKind == DateKind.SmallDateTime ? "[smalldatetime]" : "[datetime]", + DbType.Decimal => $"[decimal]({ResolveColumnLength(column)})", + DbType.Double => "[float]", + DbType.Guid => "[uniqueidentifier]", + DbType.Int16 => "[smallint]", + DbType.Int32 => "[int]", + DbType.Int64 => "[bigint]", + DbType.Object => $"[varbinary]({ResolveColumnLength(column)})", + DbType.SByte => "[smallint]", + DbType.Single => "[real]", + DbType.String => $"[nvarchar]({ResolveColumnLength(column)})", + DbType.Time => $"[time]({ResolveColumnLength(column)})", + DbType.UInt16 => "[int]", + DbType.UInt32 => "[bigint]", + DbType.UInt64 => "[float]", + DbType.VarNumeric => $"[numeric]({ResolveColumnLength(column)})", + DbType.AnsiStringFixedLength => $"[char]({ResolveColumnLength(column)})", + DbType.StringFixedLength => $"[nchar]({ResolveColumnLength(column)})", + DbType.Xml => "[xml]", + DbType.DateTime2 => $"[datetime2]({ResolveColumnLength(column)})", + DbType.DateTimeOffset => $"[datetimeoffset]({ResolveColumnLength(column)})", + _ => throw new NotSupportedException(), + }; + } + + protected static List ParseIndexes(ISchema schema) + { + var result = new List(); + + foreach (var column in schema.Columns) + { + if (column.IsPrimaryKey) + continue; + + if (column.IsIndex) + { + if (string.IsNullOrWhiteSpace(column.Index)) + { + var index = new IndexDescriptor + { + Unique = column.IsUnique, + }; + + index.Columns.Add(column.Name); + + result.Add(index); + } + else + { + var index = result.FirstOrDefault(f => string.Equals(f.Group, column.Index, StringComparison.OrdinalIgnoreCase)); + + if (index is null) + { + index = new IndexDescriptor + { + Group = column.Index, + Unique = column.IsUnique + }; + + result.Add(index); + } + + index.Columns.Add(column.Name); + } + } + } + + return result; + } +} diff --git a/Connected.Data/Schema/Sql/UniqueConstraintAdd.cs b/Connected.Data/Schema/Sql/UniqueConstraintAdd.cs new file mode 100644 index 0000000..5be35a5 --- /dev/null +++ b/Connected.Data/Schema/Sql/UniqueConstraintAdd.cs @@ -0,0 +1,42 @@ +using System.Text; + +namespace Connected.Data.Schema.Sql +{ + internal class UniqueConstraintAdd : TableTransaction + { + public UniqueConstraintAdd(IndexDescriptor index) + { + Index = index; + } + + private IndexDescriptor Index { get; } + + protected override async Task OnExecute() + { + await Context.Execute(CommandText); + } + + private string CommandText + { + get + { + var text = new StringBuilder(); + + text.AppendLine($"ALTER TABLE {Escape(Context.Schema.SchemaName(), Context.Schema.Name)}"); + text.AppendLine($"ADD CONSTRAINT [{Context.GenerateConstraintName(Context.Schema.SchemaName(), Context.Schema.Name, ConstraintNameType.Index)}] UNIQUE NONCLUSTERED ("); + var comma = string.Empty; + + foreach (var column in Index.Columns) + { + text.AppendLine($"{comma}{Escape(column)} ASC"); + + comma = ","; + } + + text.AppendLine($") ON {Escape(SchemaExtensions.FileGroup)}"); + + return text.ToString(); + } + } + } +} diff --git a/Connected.Data/Sharding/IShard.cs b/Connected.Data/Sharding/IShard.cs new file mode 100644 index 0000000..7bd33cb --- /dev/null +++ b/Connected.Data/Sharding/IShard.cs @@ -0,0 +1,10 @@ +using Connected.Data; + +namespace Connected.Data.Sharding; + +public interface IShard : IPrimaryKey +{ + int Node { get; init; } + string Entity { get; init; } + string EntityId { get; init; } +} diff --git a/Connected.Data/Sharding/IShardingMiddleware.cs b/Connected.Data/Sharding/IShardingMiddleware.cs new file mode 100644 index 0000000..f0cdd31 --- /dev/null +++ b/Connected.Data/Sharding/IShardingMiddleware.cs @@ -0,0 +1,11 @@ +using Connected.Entities.Storage; +using System.Collections.Immutable; + +namespace Connected.Data.Sharding; + +public interface IShardingMiddleware : IMiddleware +{ + bool SupportsEntity(Type entityType); + Task> ProvideNodes(IStorageOperation operation); + Task> ProvideNodes(Type entityType); +} diff --git a/Connected.Data/Sharding/IShardingNode.cs b/Connected.Data/Sharding/IShardingNode.cs new file mode 100644 index 0000000..5e477a3 --- /dev/null +++ b/Connected.Data/Sharding/IShardingNode.cs @@ -0,0 +1,9 @@ +namespace Connected.Data.Sharding; + +public interface IShardingNode : IPrimaryKey +{ + string Name { get; init; } + string ConnectionString { get; init; } + Status Status { get; init; } + string ConnectionType { get; init; } +} diff --git a/Connected.Data/Sharding/IShardingService.cs b/Connected.Data/Sharding/IShardingService.cs new file mode 100644 index 0000000..8378044 --- /dev/null +++ b/Connected.Data/Sharding/IShardingService.cs @@ -0,0 +1,10 @@ +using System.Collections.Immutable; + +namespace Connected.Data.Sharding +{ + public interface IShardingService + { + Task> Query(Type entity); + Task Select(Type entity, ShardPrimaryKeyArgs args); + } +} diff --git a/Connected.Data/Sharding/ShardPrimaryKeyArgs.cs b/Connected.Data/Sharding/ShardPrimaryKeyArgs.cs new file mode 100644 index 0000000..7cb80c8 --- /dev/null +++ b/Connected.Data/Sharding/ShardPrimaryKeyArgs.cs @@ -0,0 +1,11 @@ +using System.ComponentModel.DataAnnotations; +using Connected.ServiceModel; + +namespace Connected.Data.Sharding; + +public sealed class ShardPrimaryKeyArgs : Dto +{ + [Required] + [MaxLength(128)] + public string PrimaryKey { get; set; } = default!; +} diff --git a/Connected.Data/Sharding/ShardingService.cs b/Connected.Data/Sharding/ShardingService.cs new file mode 100644 index 0000000..b368a5e --- /dev/null +++ b/Connected.Data/Sharding/ShardingService.cs @@ -0,0 +1,17 @@ +using System.Collections.Immutable; + +namespace Connected.Data.Sharding +{ + internal class ShardingService : IShardingService + { + public Task> Query(Type entity) + { + throw new NotImplementedException(); + } + + public Task Select(Type entity, ShardPrimaryKeyArgs args) + { + throw new NotImplementedException(); + } + } +} diff --git a/Connected.Data/Sql/DatabaseCommand.cs b/Connected.Data/Sql/DatabaseCommand.cs new file mode 100644 index 0000000..40df605 --- /dev/null +++ b/Connected.Data/Sql/DatabaseCommand.cs @@ -0,0 +1,68 @@ +using Connected.Data.Storage; +using Connected.Entities.Storage; + +namespace Connected.Data.Sql; + +internal abstract class DatabaseCommand : IStorageCommand +{ + protected DatabaseCommand(IStorageOperation operation, IStorageConnection connection) + { + Connection = connection; + Operation = operation; + } + + protected bool IsDisposed { get; private set; } + public IStorageOperation Operation { get; } + public IStorageConnection? Connection { get; protected set; } + + protected virtual async ValueTask DisposeAsync(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + Connection = null; + + await OnDisposingAsync(); + } + + IsDisposed = true; + } + } + + protected virtual async ValueTask OnDisposingAsync() + { + await ValueTask.CompletedTask; + } + + protected virtual void OnDisposing() + { + } + + public async ValueTask DisposeAsync() + { + await DisposeAsync(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + Connection = null; + + OnDisposing(); + } + + IsDisposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Data/Sql/DatabaseReader.cs b/Connected.Data/Sql/DatabaseReader.cs new file mode 100644 index 0000000..2c84454 --- /dev/null +++ b/Connected.Data/Sql/DatabaseReader.cs @@ -0,0 +1,73 @@ +using Connected.Data.Storage; +using Connected.Entities.Storage; +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Data.Sql; + +internal class DatabaseReader : DatabaseCommand, IStorageReader +{ + public DatabaseReader(IStorageOperation operation, IStorageConnection connection) : base(operation, connection) + { + } + + public async Task?> Query() + { + if (Connection is null) + return default; + + try + { + var result = await Connection.Query(this); + + if (Connection.Behavior == StorageConnectionMode.Isolated) + await Connection.Commit(); + + return result; + } + finally + { + if (Connection.Behavior == StorageConnectionMode.Isolated) + { + await Connection.Close(); + await Connection.DisposeAsync(); + + Connection = null; + } + } + } + + public async Task Select() + { + try + { + if (Connection is null) + return default; + + var result = await Connection.Select(this); + + if (Connection.Behavior == StorageConnectionMode.Isolated) + await Connection.Commit(); + + return result; + } + finally + { + if (Connection.Behavior == StorageConnectionMode.Isolated) + { + await Connection.Close(); + await Connection.DisposeAsync(); + + Connection = null; + } + } + } + + public async Task OpenReader() + { + if (Connection is null) + return default; + + return await Connection.OpenReader(this); + } +} diff --git a/Connected.Data/Sql/DatabaseWriter.cs b/Connected.Data/Sql/DatabaseWriter.cs new file mode 100644 index 0000000..af150e5 --- /dev/null +++ b/Connected.Data/Sql/DatabaseWriter.cs @@ -0,0 +1,38 @@ +using Connected.Data.Storage; +using Connected.Entities.Storage; + +namespace Connected.Data.Sql; + +internal class DatabaseWriter : DatabaseCommand, IStorageWriter +{ + public DatabaseWriter(IStorageOperation operation, IStorageConnection connection) + : base(operation, connection) + { + } + + public async Task Execute() + { + if (Connection is null) + return -1; + + try + { + var recordsAffected = await Connection.Execute(this); + + if (Connection.Behavior == StorageConnectionMode.Isolated) + await Connection.Commit(); + + return recordsAffected; + } + finally + { + if (Connection.Behavior == StorageConnectionMode.Isolated) + { + await Connection.Close(); + await Connection.DisposeAsync(); + + Connection = null; + } + } + } +} diff --git a/Connected.Data/Sql/SqlDataConnection.cs b/Connected.Data/Sql/SqlDataConnection.cs new file mode 100644 index 0000000..a9cb762 --- /dev/null +++ b/Connected.Data/Sql/SqlDataConnection.cs @@ -0,0 +1,60 @@ +using Connected.Annotations; +using Connected.Data.Storage; +using Connected.ServiceModel; +using Microsoft.Data.SqlClient; +using System.Data; + +namespace Connected.Data.Sql; + +[ServiceRegistration(ServiceRegistrationMode.Auto, ServiceRegistrationScope.Transient)] +internal sealed class SqlDataConnection : DatabaseConnection +{ + public SqlDataConnection(ICancellationContext context) : base(context) + { + } + + protected override void SetupParameters(IStorageCommand command, IDbCommand cmd) + { + if (cmd.Parameters.Count > 0) + { + foreach (SqlParameter i in cmd.Parameters) + i.Value = DBNull.Value; + + return; + } + + if (command.Operation.Parameters is null) + return; + + foreach (var i in command.Operation.Parameters) + { + cmd.Parameters.Add(new SqlParameter + { + ParameterName = i.Name, + DbType = i.Type, + Direction = i.Direction + }); + } + } + + protected override object GetParameterValue(IDbCommand command, string parameterName) + { + if (command is SqlCommand cmd) + return cmd.Parameters[parameterName].Value; + + return null; + } + + protected override void SetParameterValue(IDbCommand command, string parameterName, object value) + { + if (command is SqlCommand cmd) + cmd.Parameters[parameterName].Value = value; + } + + protected override async Task OnCreateConnection() + { + await Task.CompletedTask; + + return new SqlConnection(ConnectionString); + } +} diff --git a/Connected.Data/Sql/SqlDataType.cs b/Connected.Data/Sql/SqlDataType.cs new file mode 100644 index 0000000..bb9c8c8 --- /dev/null +++ b/Connected.Data/Sql/SqlDataType.cs @@ -0,0 +1,18 @@ +using Connected.Expressions.Languages; +using System.Data; + +namespace Connected.Data.Sql; + +internal sealed class SqlDataType : DataType +{ + public SqlDataType(SqlDbType dbType, bool notNull, int length, short precision, short scale) + { + DbType = dbType; + NotNull = notNull; + Length = length; + Precision = precision; + Scale = scale; + } + + public SqlDbType DbType { get; } +} diff --git a/Connected.Data/Sql/SqlStorageProvider.cs b/Connected.Data/Sql/SqlStorageProvider.cs new file mode 100644 index 0000000..2870eb9 --- /dev/null +++ b/Connected.Data/Sql/SqlStorageProvider.cs @@ -0,0 +1,116 @@ +using Connected.Annotations; +using Connected.Data.Storage; +using Connected.Data.Update; +using Connected.Entities; +using Connected.Entities.Storage; +using Connected.Expressions; +using Connected.Expressions.Evaluation; +using Connected.Expressions.Query; +using Connected.Expressions.Translation; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Data.Sql; + +[Priority(1)] +internal sealed class SqlStorageProvider : QueryProvider, IStorageExecutor, IStorageMiddleware +{ + public SqlStorageProvider(IStorageProvider storage) + { + Storage = storage; + } + + private IStorageProvider Storage { get; } + + protected override object? OnExecute(Expression expression) + { + return CreateExecutionPlan(expression).Compile()(this); + } + + private static Expression> CreateExecutionPlan(Expression expression) + { + var lambda = expression as LambdaExpression; + + if (lambda is not null) + expression = lambda.Body; + + var context = new ExpressionCompilationContext(new TSqlLanguage()); + var translator = new Translator(context); + var translation = translator.Translate(expression); + var parameters = lambda?.Parameters; + var provider = Resolve(expression, parameters, typeof(IStorage<>)); + + if (provider is null) + { + var rootQueryable = Resolve(expression, parameters, typeof(IQueryable)); + + provider = Expression.Property(rootQueryable, typeof(IQueryable).GetTypeInfo().GetDeclaredProperty(nameof(IQueryable.Provider))); + } + + return ExecutionBuilder.Build(context, new TSqlLinguist(context, TSqlLanguage.Default, translator), translation, provider); + } + + /// + /// Find the expression of the specified type, either in the specified expression or parameters. + /// + private static Expression Resolve(Expression expression, IList parameters, Type type) + { + if (parameters is not null) + { + var found = parameters.FirstOrDefault(p => type.IsAssignableFrom(p.Type)); + + if (found is not null) + return found; + } + + return SubtreeResolver.Resolve(expression, type); + } + + public IEnumerable Execute(IStorageOperation operation) + where TResult : IEntity + { + /* + * Currently, only Shared connection is supported. Must find a way to pass + * behavior into the execution pipeline. + */ + var result = Storage.Open().Query(new StorageContextArgs(operation)); + + if (result.IsCompleted) + return result.Result; + + var r = result.GetAwaiter().GetResult(); + + return r; + } + + public bool SupportsEntity(Type entityType) + { + return entityType.IsAssignableTo(typeof(Entity)); + } + + public IStorageOperation CreateOperation(TEntity entity) + where TEntity : IEntity + { + var builder = new AggregatedCommandBuilder(); + + if (builder.Build(entity) is not StorageOperation operation) + throw new NullReferenceException(nameof(StorageOperation)); + + return operation; + } + + public IStorageReader OpenReader(IStorageOperation operation, IStorageConnection connection) + { + return new DatabaseReader(operation, connection); + } + + public IStorageWriter OpenWriter(IStorageOperation operation, IStorageConnection connection) + { + return new DatabaseWriter(operation, connection); + } + + public async Task Initialize() + { + await Task.CompletedTask; + } +} diff --git a/Connected.Data/Sql/SqlTypeSystem.cs b/Connected.Data/Sql/SqlTypeSystem.cs new file mode 100644 index 0000000..6f706d1 --- /dev/null +++ b/Connected.Data/Sql/SqlTypeSystem.cs @@ -0,0 +1,190 @@ +using Connected.Expressions.Languages; +using Connected.Expressions.TypeSystem; +using Connected.Interop; +using System.Data; +using System.Globalization; +using System.Reflection; +using System.Text; + +namespace Connected.Data.Sql; +internal sealed class SqlTypeSystem : QueryTypeSystem +{ + public static int StringDefaultSize => int.MaxValue; + public static int BinaryDefaultSize => int.MaxValue; + + public override DataType Parse(string typeDeclaration) + { + string[]? args = null; + string typeName; + string? remainder = null; + var openParen = typeDeclaration.IndexOf('('); + + if (openParen >= 0) + { + typeName = typeDeclaration[..openParen].Trim(); + + var closeParen = typeDeclaration.IndexOf(')', openParen); + + if (closeParen < openParen) + closeParen = typeDeclaration.Length; + + var argstr = typeDeclaration[(openParen + 1)..closeParen]; + + args = argstr.Split(','); + + remainder = typeDeclaration[(closeParen + 1)..]; + } + else + { + var space = typeDeclaration.IndexOf(' '); + + if (space >= 0) + { + typeName = typeDeclaration[..space]; + remainder = typeDeclaration[(space + 1)..].Trim(); + } + else + typeName = typeDeclaration; + } + + var isNotNull = (remainder is not null) && remainder.ToUpper().Contains("NOT NULL"); + + return ResolveDataType(typeName, args, isNotNull); + } + + public DataType ResolveDataType(string typeName, string[] args, bool isNotNull) + { + if (string.Equals(typeName, "rowversion", StringComparison.OrdinalIgnoreCase)) + typeName = "Timestamp"; + + if (string.Equals(typeName, "numeric", StringComparison.OrdinalIgnoreCase)) + typeName = "Decimal"; + + if (string.Equals(typeName, "sql_variant", StringComparison.OrdinalIgnoreCase)) + typeName = "Variant"; + + var dbType = ResolveSqlType(typeName); + var length = 0; + short precision = 0; + short scale = 0; + + switch (dbType) + { + case SqlDbType.Binary: + case SqlDbType.Char: + case SqlDbType.Image: + case SqlDbType.NChar: + case SqlDbType.NVarChar: + case SqlDbType.VarBinary: + case SqlDbType.VarChar: + length = args is null || !args.Any() ? 32 : string.Equals(args[0], "max", StringComparison.OrdinalIgnoreCase) ? int.MaxValue : int.Parse(args[0]); + break; + case SqlDbType.Money: + precision = args is null || !args.Any() ? (short)29 : short.Parse(args[0], NumberFormatInfo.InvariantInfo); + scale = args is null || args.Length < 2 ? (short)4 : short.Parse(args[1], NumberFormatInfo.InvariantInfo); + break; + case SqlDbType.Decimal: + precision = args is null || !args.Any() ? (short)29 : short.Parse(args[0], NumberFormatInfo.InvariantInfo); + scale = args is null || args.Length < 2 ? (short)0 : short.Parse(args[1], NumberFormatInfo.InvariantInfo); + break; + case SqlDbType.Float: + case SqlDbType.Real: + precision = args is null || !args.Any() ? (short)29 : short.Parse(args[0], NumberFormatInfo.InvariantInfo); + break; + } + + return NewType(dbType, isNotNull, length, precision, scale); + } + + private static DataType NewType(SqlDbType type, bool isNotNull, int length, short precision, short scale) + { + return new SqlDataType(type, isNotNull, length, precision, scale); + } + + public static SqlDbType ResolveSqlType(string typeName) + { + return (SqlDbType)Enum.Parse(typeof(SqlDbType), typeName, true); + } + + public override DataType ResolveColumnType(Type type) + { + var isNotNull = type.GetTypeInfo().IsValueType && !Nullables.IsNullableType(type); + type = Nullables.GetNonNullableType(type); + + switch (Interop.TypeSystem.GetTypeCode(type)) + { + case TypeCode.Boolean: + return NewType(SqlDbType.Bit, isNotNull, 0, 0, 0); + case TypeCode.SByte: + case TypeCode.Byte: + return NewType(SqlDbType.TinyInt, isNotNull, 0, 0, 0); + case TypeCode.Int16: + case TypeCode.UInt16: + return NewType(SqlDbType.SmallInt, isNotNull, 0, 0, 0); + case TypeCode.Int32: + case TypeCode.UInt32: + return NewType(SqlDbType.Int, isNotNull, 0, 0, 0); + case TypeCode.Int64: + case TypeCode.UInt64: + return NewType(SqlDbType.BigInt, isNotNull, 0, 0, 0); + case TypeCode.Single: + case TypeCode.Double: + return NewType(SqlDbType.Float, isNotNull, 0, 0, 0); + case TypeCode.String: + return NewType(SqlDbType.NVarChar, isNotNull, StringDefaultSize, 0, 0); + case TypeCode.Char: + return NewType(SqlDbType.NChar, isNotNull, 1, 0, 0); + case TypeCode.DateTime: + return NewType(SqlDbType.DateTime, isNotNull, 0, 0, 0); + case TypeCode.Decimal: + return NewType(SqlDbType.Decimal, isNotNull, 0, 29, 4); + default: + if (type == typeof(byte[])) + return NewType(SqlDbType.VarBinary, isNotNull, BinaryDefaultSize, 0, 0); + else if (type == typeof(Guid)) + return NewType(SqlDbType.UniqueIdentifier, isNotNull, 0, 0, 0); + else if (type == typeof(DateTimeOffset)) + return NewType(SqlDbType.DateTimeOffset, isNotNull, 0, 0, 0); + else if (type == typeof(TimeSpan)) + return NewType(SqlDbType.Time, isNotNull, 0, 0, 0); + else if (type.GetTypeInfo().IsEnum) + return NewType(SqlDbType.Int, isNotNull, 0, 0, 0); + else + throw new NotSupportedException(nameof(ResolveColumnType)); + } + } + + public static bool IsVariableLength(SqlDbType dbType) + { + return dbType switch + { + SqlDbType.Image or SqlDbType.NText or SqlDbType.NVarChar or SqlDbType.Text or SqlDbType.VarBinary or SqlDbType.VarChar or SqlDbType.Xml => true, + _ => false, + }; + } + + public override string Format(DataType type, bool suppressSize) + { + var sqlType = (SqlDataType)type; + var sb = new StringBuilder(); + + sb.Append(sqlType.DbType.ToString().ToUpper()); + + if (sqlType.Length > 0 && !suppressSize) + { + if (sqlType.Length == int.MaxValue) + sb.Append("(max)"); + else + sb.AppendFormat(NumberFormatInfo.InvariantInfo, "({0})", sqlType.Length); + } + else if (sqlType.Precision != 0) + { + if (sqlType.Scale != 0) + sb.AppendFormat(NumberFormatInfo.InvariantInfo, "({0},{1})", sqlType.Precision, sqlType.Scale); + else + sb.AppendFormat(NumberFormatInfo.InvariantInfo, "({0})", sqlType.Precision); + } + + return sb.ToString(); + } +} \ No newline at end of file diff --git a/Connected.Data/Sql/TSqlFormatter.cs b/Connected.Data/Sql/TSqlFormatter.cs new file mode 100644 index 0000000..972e35d --- /dev/null +++ b/Connected.Data/Sql/TSqlFormatter.cs @@ -0,0 +1,770 @@ +using Connected.Expressions; +using Connected.Expressions.Formatters; +using Connected.Expressions.Languages; +using Connected.Interop; +using System.Linq.Expressions; + +namespace Connected.Data.Sql; + +internal sealed class TSqlFormatter : SqlFormatter +{ + + public TSqlFormatter(ExpressionCompilationContext context, QueryLanguage? language) + : base(language) + { + Context = context; + } + public ExpressionCompilationContext Context { get; } + + public static new string Format(ExpressionCompilationContext context, Expression expression) + { + return Format(context, expression, new TSqlLanguage()); + } + public static string Format(ExpressionCompilationContext context, Expression expression, QueryLanguage language) + { + var formatter = new TSqlFormatter(context, language); + + formatter.Visit(expression); + + return formatter.ToString(); + } + protected override void WriteAggregateName(string aggregateName) + { + if (string.Equals(aggregateName, "LongCount", StringComparison.Ordinal)) + Write("COUNT_BIG"); + else + base.WriteAggregateName(aggregateName); + } + protected override Expression VisitMemberAccess(MemberExpression m) + { + if (m.Member.DeclaringType == typeof(string)) + { + switch (m.Member.Name) + { + case "Length": + Write("LEN("); + Visit(m.Expression); + Write(")"); + return m; + } + } + else if (m.Member.DeclaringType == typeof(DateTime) || m.Member.DeclaringType == typeof(DateTimeOffset)) + { + switch (m.Member.Name) + { + case "Day": + Write("DAY("); + Visit(m.Expression); + Write(")"); + return m; + case "Month": + Write("MONTH("); + Visit(m.Expression); + Write(")"); + return m; + case "Year": + Write("YEAR("); + Visit(m.Expression); + Write(")"); + return m; + case "Hour": + Write("DATEPART(hour, "); + Visit(m.Expression); + Write(")"); + return m; + case "Minute": + Write("DATEPART(minute, "); + Visit(m.Expression); + Write(")"); + return m; + case "Second": + Write("DATEPART(second, "); + Visit(m.Expression); + Write(")"); + return m; + case "Millisecond": + Write("DATEPART(millisecond, "); + Visit(m.Expression); + Write(")"); + return m; + case "DayOfWeek": + Write("(DATEPART(weekday, "); + Visit(m.Expression); + Write(") - 1)"); + return m; + case "DayOfYear": + Write("(DATEPART(dayofyear, "); + Visit(m.Expression); + Write(") - 1)"); + return m; + } + } + + return base.VisitMemberAccess(m); + } + + protected override Expression VisitMethodCall(MethodCallExpression m) + { + if (m.Method.DeclaringType == typeof(string)) + { + switch (m.Method.Name) + { + case "StartsWith": + Write("("); + Visit(m.Object); + Write(" LIKE "); + Visit(m.Arguments[0]); + Write(" + '%')"); + return m; + case "EndsWith": + Write("("); + Visit(m.Object); + Write(" LIKE '%' + "); + Visit(m.Arguments[0]); + Write(")"); + return m; + case "Contains": + Write("("); + Visit(m.Object); + Write(" LIKE '%' + "); + Visit(m.Arguments[0]); + Write(" + '%')"); + return m; + case "Concat": + var args = m.Arguments; + + if (args.Count == 1 && args[0].NodeType == ExpressionType.NewArrayInit) + args = ((NewArrayExpression)args[0]).Expressions; + for (var i = 0; i < args.Count; i++) + { + if (i > 0) + Write(" + "); + + Visit(args[i]); + } + return m; + case "IsNullOrEmpty": + Write("("); + Visit(m.Arguments[0]); + Write(" IS NULL OR "); + Visit(m.Arguments[0]); + Write(" = '')"); + return m; + case "ToUpper": + Write("UPPER("); + Visit(m.Object); + Write(")"); + return m; + case "ToLower": + Write("LOWER("); + Visit(m.Object); + Write(")"); + return m; + case "Replace": + Write("REPLACE("); + Visit(m.Object); + Write(", "); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + case "Substring": + Write("SUBSTRING("); + Visit(m.Object); + Write(", "); + Visit(m.Arguments[0]); + Write(" + 1, "); + + if (m.Arguments.Count == 2) + Visit(m.Arguments[1]); + else + Write("8000"); + + Write(")"); + return m; + case "Remove": + Write("STUFF("); + Visit(m.Object); + Write(", "); + Visit(m.Arguments[0]); + Write(" + 1, "); + + if (m.Arguments.Count == 2) + Visit(m.Arguments[1]); + else + Write("8000"); + + Write(", '')"); + return m; + case "IndexOf": + Write("(CHARINDEX("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Object); + + if (m.Arguments.Count == 2 && m.Arguments[1].Type == typeof(int)) + { + Write(", "); + Visit(m.Arguments[1]); + Write(" + 1"); + } + + Write(") - 1)"); + return m; + case "Trim": + Write("RTRIM(LTRIM("); + Visit(m.Object); + Write("))"); + return m; + } + } + else if (m.Method.DeclaringType == typeof(DateTime)) + { + switch (m.Method.Name) + { + case "op_Subtract": + if (m.Arguments[1].Type == typeof(DateTime)) + { + Write("DATEDIFF("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + } + break; + case "AddYears": + Write("DATEADD(YYYY,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddMonths": + Write("DATEADD(MM,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddDays": + Write("DATEADD(DAY,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddHours": + Write("DATEADD(HH,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddMinutes": + Write("DATEADD(MI,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddSeconds": + Write("DATEADD(SS,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + case "AddMilliseconds": + Write("DATEADD(MS,"); + Visit(m.Arguments[0]); + Write(","); + Visit(m.Object); + Write(")"); + return m; + } + } + else if (m.Method.DeclaringType == typeof(Decimal)) + { + switch (m.Method.Name) + { + case "Add": + case "Subtract": + case "Multiply": + case "Divide": + case "Remainder": + Write("("); + VisitValue(m.Arguments[0]); + Write(" "); + Write(GetOperator(m.Method.Name)); + Write(" "); + VisitValue(m.Arguments[1]); + Write(")"); + return m; + case "Negate": + Write("-"); + Visit(m.Arguments[0]); + Write(""); + return m; + case "Ceiling": + case "Floor": + Write(m.Method.Name.ToUpper()); + Write("("); + Visit(m.Arguments[0]); + Write(")"); + return m; + case "Round": + if (m.Arguments.Count == 1) + { + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", 0)"); + return m; + } + else if (m.Arguments.Count == 2 && m.Arguments[1].Type == typeof(int)) + { + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + } + break; + case "Truncate": + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", 0, 1)"); + return m; + } + } + else if (m.Method.DeclaringType == typeof(Math)) + { + switch (m.Method.Name) + { + case "Abs": + case "Acos": + case "Asin": + case "Atan": + case "Cos": + case "Exp": + case "Log10": + case "Sin": + case "Tan": + case "Sqrt": + case "Sign": + case "Ceiling": + case "Floor": + Write(m.Method.Name.ToUpper()); + Write("("); + Visit(m.Arguments[0]); + Write(")"); + return m; + case "Atan2": + Write("ATN2("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + case "Log": + if (m.Arguments.Count == 1) + goto case "Log10"; + + break; + case "Pow": + Write("POWER("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + case "Round": + if (m.Arguments.Count == 1) + { + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", 0)"); + return m; + } + else if (m.Arguments.Count == 2 && m.Arguments[1].Type == typeof(int)) + { + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", "); + Visit(m.Arguments[1]); + Write(")"); + return m; + } + break; + case "Truncate": + Write("ROUND("); + Visit(m.Arguments[0]); + Write(", 0, 1)"); + return m; + } + } + if (m.Method.Name == "ToString") + { + if (m.Object.Type != typeof(string)) + { + Write("CONVERT(NVARCHAR, "); + Visit(m.Object); + Write(")"); + } + else + Visit(m.Object); + + return m; + } + else if (!m.Method.IsStatic && string.Equals(m.Method.Name, "CompareTo", StringComparison.Ordinal) && m.Method.ReturnType == typeof(int) && m.Arguments.Count == 1) + { + Write("(CASE WHEN "); + Visit(m.Object); + Write(" = "); + Visit(m.Arguments[0]); + Write(" THEN 0 WHEN "); + Visit(m.Object); + Write(" < "); + Visit(m.Arguments[0]); + Write(" THEN -1 ELSE 1 END)"); + return m; + } + else if (m.Method.IsStatic && string.Equals(m.Method.Name, "Compare", StringComparison.Ordinal) && m.Method.ReturnType == typeof(int) && m.Arguments.Count == 2) + { + Write("(CASE WHEN "); + Visit(m.Arguments[0]); + Write(" = "); + Visit(m.Arguments[1]); + Write(" THEN 0 WHEN "); + Visit(m.Arguments[0]); + Write(" < "); + Visit(m.Arguments[1]); + Write(" THEN -1 ELSE 1 END)"); + return m; + } + else if (m.Method.DeclaringType == typeof(TypeComparer) && m.Method.IsStatic && string.Equals(m.Method.Name, nameof(TypeComparer.Compare), StringComparison.Ordinal) && m.Method.ReturnType == typeof(bool) && m.Arguments.Count == 2) + { + Visit(m.Arguments[0]); + Write(" = "); + Visit(m.Arguments[1]); + return m; + } + return base.VisitMethodCall(m); + } + + protected override NewExpression VisitNew(NewExpression nex) + { + if (nex.Constructor.DeclaringType == typeof(DateTime)) + { + if (nex.Arguments.Count == 3) + { + Write("Convert(DateTime, "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[0]); + Write(") + '/' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[1]); + Write(") + '/' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[2]); + Write("))"); + return nex; + } + else if (nex.Arguments.Count == 6) + { + Write("Convert(DateTime, "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[0]); + Write(") + '/' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[1]); + Write(") + '/' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[2]); + Write(") + ' ' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[3]); + Write(") + ':' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[4]); + Write(") + ':' + "); + Write("Convert(nvarchar, "); + Visit(nex.Arguments[5]); + Write("))"); + return nex; + } + } + + return base.VisitNew(nex); + } + + protected override Expression VisitBinary(BinaryExpression b) + { + if (b.NodeType == ExpressionType.Power) + { + Write("POWER("); + VisitValue(b.Left); + Write(", "); + VisitValue(b.Right); + Write(")"); + return b; + } + else if (b.NodeType == ExpressionType.Coalesce) + { + Write("COALESCE("); + VisitValue(b.Left); + Write(", "); + + var right = b.Right; + + while (right.NodeType == ExpressionType.Coalesce) + { + var rb = (BinaryExpression)right; + + VisitValue(rb.Left); + Write(", "); + + right = rb.Right; + } + + VisitValue(right); + Write(")"); + + return b; + } + else if (b.NodeType == ExpressionType.LeftShift) + { + Write("("); + VisitValue(b.Left); + Write(" * POWER(2, "); + VisitValue(b.Right); + Write("))"); + return b; + } + else if (b.NodeType == ExpressionType.RightShift) + { + Write("("); + VisitValue(b.Left); + Write(" / POWER(2, "); + VisitValue(b.Right); + Write("))"); + return b; + } + + return base.VisitBinary(b); + } + + protected override Expression VisitConstant(ConstantExpression c) + { + var parameter = Context.Parameters.FirstOrDefault(f => f.Value == c); + + if (parameter.Value is not null) + { + Write($"@{parameter.Key}"); + + return c; + } + + return base.VisitConstant(c); + } + + protected override Expression VisitValue(Expression expr) + { + if (IsPredicate(expr)) + { + Write("CASE WHEN ("); + Visit(expr); + Write(") THEN 1 ELSE 0 END"); + + return expr; + } + + return base.VisitValue(expr); + } + + protected override Expression VisitConditional(ConditionalExpression c) + { + if (IsPredicate(c.Test)) + { + Write("(CASE WHEN "); + VisitPredicate(c.Test); + Write(" THEN "); + VisitValue(c.IfTrue); + + var ifFalse = c.IfFalse; + + while (ifFalse is not null && ifFalse.NodeType == ExpressionType.Conditional) + { + var fc = (ConditionalExpression)ifFalse; + + Write(" WHEN "); + VisitPredicate(fc.Test); + Write(" THEN "); + VisitValue(fc.IfTrue); + + ifFalse = fc.IfFalse; + } + if (ifFalse is not null) + { + Write(" ELSE "); + VisitValue(ifFalse); + } + + Write(" END)"); + } + else + { + Write("(CASE "); + VisitValue(c.Test); + Write(" WHEN 0 THEN "); + VisitValue(c.IfFalse); + Write(" ELSE "); + VisitValue(c.IfTrue); + Write(" END)"); + } + + return c; + } + + protected override Expression VisitRowNumber(RowNumberExpression rowNumber) + { + Write("ROW_NUMBER() OVER("); + + if (rowNumber.OrderBy is not null && rowNumber.OrderBy.Any()) + { + Write("ORDER BY "); + + for (var i = 0; i < rowNumber.OrderBy.Count; i++) + { + var exp = rowNumber.OrderBy[i]; + + if (i > 0) + Write(", "); + + VisitValue(exp.Expression); + + if (exp.OrderType != OrderType.Ascending) + Write(" DESC"); + } + } + + Write(")"); + + return rowNumber; + } + + protected override Expression VisitIf(IfCommandExpression ifx) + { + if (!Language.AllowsMultipleCommands) + return base.VisitIf(ifx); + + Write("IF "); + Visit(ifx.Check); + WriteLine(Indentation.Same); + Write("BEGIN"); + WriteLine(Indentation.Inner); + VisitStatement(ifx.IfTrue); + WriteLine(Indentation.Outer); + + if (ifx.IfFalse is not null) + { + Write("END ELSE BEGIN"); + WriteLine(Indentation.Inner); + VisitStatement(ifx.IfFalse); + WriteLine(Indentation.Outer); + } + + Write("END"); + + return ifx; + } + + protected override Expression VisitBlock(Expressions.BlockExpression block) + { + if (!Language.AllowsMultipleCommands) + return base.VisitBlock(block); + + for (var i = 0; i < block.Commands.Count; i++) + { + if (i > 0) + { + WriteLine(Indentation.Same); + WriteLine(Indentation.Same); + } + + VisitStatement(block.Commands[i]); + } + + return block; + } + + protected override Expression VisitDeclaration(DeclarationExpression decl) + { + if (!Language.AllowsMultipleCommands) + return base.VisitDeclaration(decl); + + for (var i = 0; i < decl.Variables.Count; i++) + { + var v = decl.Variables[i]; + + if (i > 0) + WriteLine(Indentation.Same); + + Write("DECLARE @"); + Write(v.Name); + Write(" "); + Write(Language.TypeSystem.Format(v.DataType, false)); + } + + if (decl.Source is not null) + { + WriteLine(Indentation.Same); + Write("SELECT "); + + for (var i = 0; i < decl.Variables.Count; i++) + { + if (i > 0) + Write(", "); + + Write("@"); + Write(decl.Variables[i].Name); + Write(" = "); + Visit(decl.Source.Columns[i].Expression); + } + + if (decl.Source.From is not null) + { + WriteLine(Indentation.Same); + Write("FROM "); + VisitSource(decl.Source.From); + } + + if (decl.Source.Where is not null) + { + WriteLine(Indentation.Same); + Write("WHERE "); + Visit(decl.Source.Where); + } + } + else + { + for (var i = 0; i < decl.Variables.Count; i++) + { + var v = decl.Variables[i]; + + if (v.Expression is not null) + { + WriteLine(Indentation.Same); + Write("SET @"); + Write(v.Name); + Write(" = "); + Visit(v.Expression); + } + } + } + + return decl; + } +} diff --git a/Connected.Data/Sql/TSqlLanguage.cs b/Connected.Data/Sql/TSqlLanguage.cs new file mode 100644 index 0000000..97450df --- /dev/null +++ b/Connected.Data/Sql/TSqlLanguage.cs @@ -0,0 +1,53 @@ +using Connected.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; +using Connected.Expressions.TypeSystem; + +namespace Connected.Data.Sql; + +internal sealed class TSqlLanguage : QueryLanguage +{ + private static TSqlLanguage? _default; + + static TSqlLanguage() + { + SplitChars = new char[] { '.' }; + } + + public TSqlLanguage() + { + TypeSystem = new SqlTypeSystem(); + } + + public override QueryTypeSystem TypeSystem { get; } + private static char[] SplitChars { get; } + public override bool AllowsMultipleCommands => true; + public override bool AllowSubqueryInSelectWithoutFrom => true; + public override bool AllowDistinctInAggregates => true; + + public static TSqlLanguage Default + { + get + { + if (_default is null) + Interlocked.CompareExchange(ref _default, new TSqlLanguage(), null); + + return _default; + } + } + + public override string Quote(string name) + { + if (name.StartsWith("[") && name.EndsWith("]")) + return name; + else if (name.Contains('.')) + return $"[{string.Join("].[", name.Split(SplitChars, StringSplitOptions.RemoveEmptyEntries))}]"; + else + return $"[{name}]"; + } + + public override Linguist CreateLinguist(ExpressionCompilationContext context, Translator translator) + { + return new TSqlLinguist(context, this, translator); + } +} \ No newline at end of file diff --git a/Connected.Data/Sql/TSqlLinguist.cs b/Connected.Data/Sql/TSqlLinguist.cs new file mode 100644 index 0000000..e1fca67 --- /dev/null +++ b/Connected.Data/Sql/TSqlLinguist.cs @@ -0,0 +1,41 @@ +using Connected.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; +using Connected.Expressions.Translation.Rewriters; +using System.Linq.Expressions; + +namespace Connected.Data.Sql; + +internal sealed class TSqlLinguist : Linguist +{ + + public TSqlLinguist(ExpressionCompilationContext context, TSqlLanguage language, Translator translator) + : base(context, language, translator) + { + } + + public override Expression Translate(Expression expression) + { + /* + * fix up any order-by's + */ + expression = OrderByRewriter.Rewrite(Language, expression); + + expression = base.Translate(expression); + /* + * convert skip/take info into RowNumber pattern + */ + expression = SkipToRowNumberRewriter.Rewrite(Language, expression); + /* + * fix up any order-by's we may have changed + */ + expression = OrderByRewriter.Rewrite(Language, expression); + + return expression; + } + + public override string Format(Expression expression) + { + return TSqlFormatter.Format(Context, expression, Language); + } +} diff --git a/Connected.Data/Storage/ConnectionProvider.cs b/Connected.Data/Storage/ConnectionProvider.cs new file mode 100644 index 0000000..ffc50c4 --- /dev/null +++ b/Connected.Data/Storage/ConnectionProvider.cs @@ -0,0 +1,192 @@ +using Connected.Data.Schema; +using Connected.Data.Sharding; +using Connected.Entities.Storage; +using Connected.Middleware; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using System.Collections.Immutable; + +namespace Connected.Data.Storage; + +internal sealed class ConnectionProvider : IConnectionProvider, IAsyncDisposable, IDisposable +{ + private List _connections; + + public ConnectionProvider(IContext context, IMiddlewareService middleware, ITransactionContext transaction) + { + Context = context; + Middleware = middleware; + TransactionService = transaction; + TransactionService.StateChanged += OnTransactionStateChanged; + + _connections = new(); + } + + public IContext Context { get; } + public IMiddlewareService Middleware { get; } + private ITransactionContext TransactionService { get; } + private List Connections => _connections; + public StorageConnectionMode Mode { get; set; } = StorageConnectionMode.Isolated; + public async Task> Open(StorageContextArgs args) + { + /* + * Isolated transactions are supported only during active TransactionService state. + */ + if (TransactionService.State == MiddlewareTransactionState.Completed) + Mode = StorageConnectionMode.Isolated; + + return args is ISchemaSynchronizationContext context ? ResolveSingle(context) : await ResolveMultiple(args); + } + /// + /// This method is called if the supplied arguments already provided connection type on which they will perform operations. + /// + /// + /// This method is usually called when synchronizing entities because the synhronization process already knows what connections + /// should be used. + /// + /// + /// + /// + /// + private ImmutableList ResolveSingle(ISchemaSynchronizationContext args) + { + return new List { EnsureConnection(args.ConnectionType, args.ConnectionString) }.ToImmutableList(); + } + + private async Task> ResolveMultiple(StorageContextArgs args) + { + var connectionMiddleware = await ResolveConnectionMiddleware(); + /* + * Check if sharding is supported on the entity + */ + var middleware = await Middleware.Query(); + IShardingMiddleware? shardingMiddleware = null; + + foreach (var m in middleware) + { + if (m.SupportsEntity(typeof(TEntity))) + { + shardingMiddleware = m; + break; + } + } + + var result = new List + { + /* + * Default connection is always used regardless of sharding support + */ + EnsureConnection(connectionMiddleware.ConnectionType, connectionMiddleware.DefaultConnectionString) + }; + + if (shardingMiddleware is not null) + { + foreach (var node in await shardingMiddleware.ProvideNodes(args.Operation)) + { + /* + * Sharding is only supported on connection of the same type. + */ + if (!string.Equals(node.ConnectionType, connectionMiddleware.ConnectionType.FullName, StringComparison.Ordinal)) + throw new ArgumentException("Sharding connection types mismatch ({connectionType})", node.ConnectionType); + + if (Type.GetType(node.ConnectionType) is not Type connectionType) + throw new NullReferenceException(node.ConnectionType); + + result.Add(EnsureConnection(connectionType, node.ConnectionString)); + } + } + + return result.ToImmutableList(); + } + + private IStorageConnection EnsureConnection(Type connectionType, string connectionString) + { + if (Mode == StorageConnectionMode.Shared + && Connections.FirstOrDefault(f => f.GetType() == connectionType + && string.Equals(f.ConnectionString, connectionString, StringComparison.OrdinalIgnoreCase)) is IStorageConnection existing) + { + return existing; + } + else + return CreateConnection(connectionType, connectionString, Mode); + } + + private IStorageConnection CreateConnection(Type connectionType, string connectionString, StorageConnectionMode behavior) + { + if (Context.GetService(connectionType) is not IStorageConnection newConnection) + throw new NullReferenceException(connectionType.Name); + + newConnection.Initialize(new StorageConnectionArgs(connectionString, behavior)); + + if (behavior == StorageConnectionMode.Shared) + Connections.Add(newConnection); + + return newConnection; + } + + private async Task ResolveConnectionMiddleware() + { + var middleware = await Middleware.Query(); + + foreach (var m in middleware) + { + if (await m.IsEntitySupported(typeof(TEntity))) + return m; + } + + throw new NullReferenceException(nameof(ResolveConnectionMiddleware)); + } + private async void OnTransactionStateChanged(object? sender, EventArgs e) + { + if (TransactionService.State == MiddlewareTransactionState.Committing) + await Commit(); + else if (TransactionService.State == MiddlewareTransactionState.Reverting) + await Rollback(); + } + + private async Task Commit() + { + foreach (var connection in Connections) + await connection.Commit(); + } + + private async Task Rollback() + { + foreach (var connection in Connections) + await connection.Rollback(); + } + public async ValueTask DisposeAsync() + { + if (_connections is not null) + { + foreach (var connection in _connections) + await connection.DisposeAsync().ConfigureAwait(false); + + _connections = null; + } + + Dispose(false); + + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + if (_connections is not null) + { + foreach (var connection in _connections) + connection.Dispose(); + + _connections = null; + } + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Data/Storage/DatabaseConnection.cs b/Connected.Data/Storage/DatabaseConnection.cs new file mode 100644 index 0000000..a382f90 --- /dev/null +++ b/Connected.Data/Storage/DatabaseConnection.cs @@ -0,0 +1,440 @@ +using Connected.Entities.Storage; +using Connected.Middleware; +using Connected.ServiceModel; +using Connected.Threading; +using System.Collections.Concurrent; +using System.Collections.Immutable; +using System.Data; +using System.Data.Common; + +namespace Connected.Data.Storage; + +public abstract class DatabaseConnection : MiddlewareComponent, IStorageConnection +{ + private readonly AsyncLocker _lock = new(); + private IDbConnection _connection; + private ConcurrentDictionary _commands = null; + + protected DatabaseConnection(ICancellationContext context) + { + Context = context; + + Commands = new(); + } + + protected ICancellationContext Context { get; } + private IDbTransaction Transaction { get; set; } + public StorageConnectionMode Behavior { get; set; } + public string ConnectionString { get; private set; } + + public async Task Initialize(StorageConnectionArgs args) + { + ConnectionString = args.ConnectionString; + Behavior = args.Behavior; + + await OnInitialize(); + } + + protected virtual async Task OnInitialize() + { + await Task.CompletedTask; + } + + protected abstract Task OnCreateConnection(); + + private async Task GetConnection() + { + if (_connection is null) + { + await _lock.LockAsync(1, async () => + { + _connection ??= await OnCreateConnection(); + }); + } + + return _connection; + } + + private ConcurrentDictionary Commands { get; } + + public async Task Commit() + { + if (Transaction is null || Transaction.Connection is null) + return; + + await _lock.LockAsync(2, async () => + { + if (Transaction is null || Transaction.Connection is null) + return; + + if (Transaction is DbTransaction db) + await db.CommitAsync(Context is null ? CancellationToken.None : Context.CancellationToken); + else + Transaction.Commit(); + + Transaction.Dispose(); + Transaction = null; + }); + + await Task.CompletedTask; + } + + public async Task Rollback() + { + if (Transaction is null || Transaction.Connection is null) + return; + + await _lock.LockAsync(3, async () => + { + if (Transaction is null || Transaction.Connection is null) + return; + + try + { + if (Transaction is DbTransaction db) + await db.RollbackAsync(Context is null ? CancellationToken.None : Context.CancellationToken); + else + Transaction.Rollback(); + } + catch { } + }); + + await Task.CompletedTask; + } + + private async Task Open() + { + var connection = await GetConnection(); + + if (connection.State == ConnectionState.Open) + return; + + await _lock.LockAsync(4, async () => + { + if (connection.State != ConnectionState.Closed) + return; + + if (connection is DbConnection db) + await db.OpenAsync(Context is null ? CancellationToken.None : Context.CancellationToken); + else + connection.Open(); + }); + + await Task.CompletedTask; + } + + public async Task Close() + { + if (_connection is null) + return; + + if (_connection is not null && _connection.State == ConnectionState.Open) + { + await _lock.LockAsync(5, async () => + { + if (_connection is not null && _connection.State == ConnectionState.Open) + { + if (Transaction is not null && Transaction.Connection is not null) + { + try + { + if (Transaction is DbTransaction db) + await db.RollbackAsync(Context is null ? CancellationToken.None : Context.CancellationToken); + else + Transaction.Rollback(); + } + catch { } + + } + + if (_connection is DbConnection dbc) + await dbc.CloseAsync(); + else + _connection.Close(); + } + }); + } + + await Task.CompletedTask; + } + + public async Task Execute(IStorageCommand command) + { + await EnsureOpen(true); + + var com = await ResolveCommand(command); + + SetupParameters(command, com); + + if (command.Operation.Parameters is not null) + { + foreach (var i in command.Operation.Parameters) + SetParameterValue(com, i.Name, i.Value); + } + + var recordsAffected = await OnExecute(command, com); + + if (command.Operation.Parameters is not null) + { + foreach (var i in command.Operation.Parameters) + { + if (i.Direction == ParameterDirection.Output) + i.Value = GetParameterValue(com, i.Name); + } + } + + return recordsAffected; + + } + protected virtual void SetParameterValue(IDbCommand command, string parameterName, object value) + { + + } + + protected virtual object? GetParameterValue(IDbCommand command, string parameterName) + { + return default; + } + + protected virtual void SetupParameters(IStorageCommand command, IDbCommand cmd) + { + } + + protected virtual async Task OnExecute(IStorageCommand command, IDbCommand cmd) + { + if (cmd is DbCommand dbCommand) + return await dbCommand.ExecuteNonQueryAsync(Context is null ? CancellationToken.None : Context.CancellationToken); + else + return cmd.ExecuteNonQuery(); + } + + public virtual async Task> Query(IStorageCommand command) + { + await EnsureOpen(false); + + var com = await ResolveCommand(command); + + IDataReader rdr = null; + + try + { + SetupParameters(command, com); + + if (command.Operation.Parameters is not null) + { + foreach (var i in command.Operation.Parameters) + SetParameterValue(com, i.Name, i.Value); + } + + rdr = com is DbCommand db ? await db.ExecuteReaderAsync(Context is null ? CancellationToken.None : Context.CancellationToken) : com.ExecuteReader(); + var result = new List(); + var mappings = new FieldMappings(rdr); + + while (rdr.Read()) + result.Add(mappings.CreateInstance(rdr)); + + return result.ToImmutableList(); + } + finally + { + if (rdr != null && !rdr.IsClosed) + { + if (rdr is DbDataReader db) + await db.CloseAsync(); + else + rdr.Close(); + } + } + } + + public virtual async Task Select(IStorageCommand command) + { + await EnsureOpen(false); + + var com = await ResolveCommand(command); + + IDataReader rdr = null; + + try + { + SetupParameters(command, com); + + if (command.Operation.Parameters is not null) + { + foreach (var i in command.Operation.Parameters) + SetParameterValue(com, i.Name, i.Value); + } + + rdr = com.ExecuteReader(CommandBehavior.SingleRow); + var mappings = new FieldMappings(rdr); + + if (rdr.Read()) + return mappings.CreateInstance(rdr); + + return default; + } + finally + { + if (rdr != null && !rdr.IsClosed) + { + if (rdr is DbDataReader db) + await db.CloseAsync(); + else + rdr.Close(); + } + } + } + + public virtual async Task OpenReader(IStorageCommand command) + { + await EnsureOpen(false); + + var com = await ResolveCommand(command); + + SetupParameters(command, com); + + if (command.Operation.Parameters is not null) + { + foreach (var i in command.Operation.Parameters) + SetParameterValue(com, i.Name, i.Value); + } + + return com.ExecuteReader(); + } + + protected virtual async Task ResolveCommand(IStorageCommand command) + { + if (Commands.TryGetValue(command, out IDbCommand? existing)) + return existing; + + if (Commands.TryGetValue(command, out IDbCommand? existing2)) + return existing2; + + return await _lock.LockAsync(6, async () => + { + var connection = await GetConnection(); + + var r = connection.CreateCommand(); + + r.CommandText = command.Operation.CommandText; + r.CommandType = command.Operation.CommandType; + r.CommandTimeout = command.Operation.CommandTimeout; + + if (Transaction is not null) + r.Transaction = Transaction; + + Commands.TryAdd(command, r); + + return r; + }); + } + + private async Task EnsureOpen(bool createTransaction) + { + var connection = await GetConnection(); + + if (connection is null || connection.State == ConnectionState.Open) + return; + + await _lock.LockAsync(7, async () => + { + await Open(); + + if (createTransaction && Transaction is null) + { + Transaction = connection is DbConnection dbc + ? await dbc.BeginTransactionAsync(Context is null ? CancellationToken.None : Context.CancellationToken) + : connection.BeginTransaction(IsolationLevel.ReadCommitted); + } + }); + } + public void Dispose() + { + OnDispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void OnDispose(bool disposing) + { + AsyncUtils.RunSync(() => Close()); + + if (_commands is not null) + { + foreach (var command in _commands) + command.Value.Dispose(); + + _commands = null; + } + + if (Transaction is not null) + { + try + { + AsyncUtils.RunSync(Rollback); + + Transaction.Dispose(); + } + catch { } + + Transaction = null; + } + + if (_connection is not null) + { + _connection.Dispose(); + _connection = null; + } + } + public async ValueTask DisposeAsync() + { + await OnDisposeAsyncCore().ConfigureAwait(false); + + OnDispose(false); + GC.SuppressFinalize(this); + } + + protected virtual async ValueTask OnDisposeAsyncCore() + { + await Close().ConfigureAwait(false); + + if (_commands is not null) + { + foreach (var command in _commands) + { + if (command.Value is DbCommand db) + await db.DisposeAsync().ConfigureAwait(false); + else + command.Value.Dispose(); + } + + _commands = null; + } + + if (Transaction is not null) + { + try + { + //No way to check if possible + await Rollback().ConfigureAwait(false); + + if (Transaction is DbTransaction dbt) + await dbt.DisposeAsync().ConfigureAwait(false); + else + Transaction.Dispose(); + } + catch { } + + Transaction = null; + } + + if (_connection is not null) + { + if (_connection is DbConnection dbc) + await dbc.DisposeAsync().ConfigureAwait(false); + else + _connection.Dispose(); + + _connection = null; + } + } +} \ No newline at end of file diff --git a/Connected.Data/Storage/EntityStorage.cs b/Connected.Data/Storage/EntityStorage.cs new file mode 100644 index 0000000..a83d8ea --- /dev/null +++ b/Connected.Data/Storage/EntityStorage.cs @@ -0,0 +1,520 @@ +using System.Collections; +using System.Collections.Immutable; +using System.Data; +using System.Linq.Expressions; +using Connected.Data.DataProtection; +using Connected.Entities; +using Connected.Entities.Storage; +using Connected.Middleware; +using Connected.ServiceModel.Transactions; +using Connected.Validation; + +namespace Connected.Data.Storage; +/// +/// Provides read and write operations on the supported storage providers. +/// +/// >The type of the entitiy on which operations are performed. +internal class EntityStorage : IAsyncEnumerable, IStorage + where TEntity : IEntity +{ + private IQueryProvider _provider; + /// + /// Creates a new instance. + /// + /// Middleware for protecting transactions and data access. + /// Middleware for recurring validation. + /// Middleware for providing storage connections. + /// Saga Transactions orchestration. + /// The storage connection behavior. + public EntityStorage(IEntityProtectionService dataProtection, IMiddlewareService middleware, IConnectionProvider connections, + ITransactionContext transactions) + { + EntityProtection = dataProtection; + Middleware = middleware; + Connections = connections; + Transactions = transactions; + Expression = Expression.Constant(this); + } + /// + /// The middleware used when performing entity operations. + /// + private IStorageMiddleware StorageMiddleware + { + get + { + if (Provider is not IStorageMiddleware result) + throw new InvalidCastException(nameof(IStorageMiddleware)); + + return result; + } + } + /// + /// The expression used for retrieving entities. + /// + public Expression Expression { get; } + /// + /// The entity type for which operations are performed. + /// + public Type ElementType => typeof(TEntity); + /// + /// The provider used when querying entities. It is based on the . + /// + public IQueryProvider Provider => _provider; + /// + /// Middleware used for protecting data access and manipulation. + /// + private IEntityProtectionService EntityProtection { get; } + /// + /// Middleware used for validation in concurrency transactions. + /// + private IMiddlewareService Middleware { get; } + /// + /// Middleware for retrieving storage connections. + /// + private IConnectionProvider Connections { get; } + /// + /// Middleware providing saga transactions orchestration. + /// + private ITransactionContext Transactions { get; } + /// + /// Gets enumerator for entities retrieved via . + /// + /// Enumerator containing entities. + public IEnumerator GetEnumerator() + { + var result = Provider.Execute(Expression); + /* + * Make sure we always return non nullable value. + */ + if (result is null) + return new List().GetEnumerator(); + + return ((IEnumerable)result).GetEnumerator(); + } + /// + /// Gets enumerator for entities retrieved via . + /// + /// Enumerator containing entities. + IEnumerator IEnumerable.GetEnumerator() + { + var result = Provider.Execute(Expression); + /* + * Make sure we always return non nullable value. + */ + if (result is null) + return new List().GetEnumerator(); + + return ((IEnumerable)result).GetEnumerator(); + } + /// + /// Gets enumerator for asynchronoues entity retrieval via . + /// + /// Token that enables operation to be cancelled. + /// Asynchronous enumerator containing entities. + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + var result = Provider.Execute(Expression); + + if (result is IEnumerable en) + { + var enumerator = en.GetEnumerator(); + + while (enumerator.MoveNext()) + { + await Task.CompletedTask; + + yield return (TEntity)enumerator.Current; + } + } + } + + public override string ToString() + { + if (Expression.NodeType == ExpressionType.Constant && ((ConstantExpression)Expression).Value == this) + return "Query(" + typeof(TEntity) + ")"; + else + return Expression.ToString(); + } + /// + /// Performs the update on the specified entity. + /// + /// The entity to be updated. + /// Entity with an id if insert has been executed, the same entity otherwise. + /// Thrown if the storage statement could no be created. + public async Task Update(TEntity? entity) + { + if (entity is null) + return entity; + + await EntityProtection.Invoke(new EntityProtection.EntityProtectionArgs(entity, entity.State)); + + var operation = StorageMiddleware.CreateOperation(entity); + + await Execute(new StorageContextArgs(operation)); + + return entity; + } + /// + /// Performs the update on the specified entity with optional concurrency callback support. + /// + /// The type of the arguments used to update the entity. + /// The entity to update. + /// The arguments that supplied updated values. + /// The retry delegate for preparing new update. + /// The entity with the newly inserted id if insert was performed, the same entity otherwise. + public async Task Update(TEntity? entity, TArgs args, Func>? concurrencyRetrying) + where TArgs : IDto + { + return await Update(entity, args, concurrencyRetrying, null); + } + /// + /// Updates the entity to the underlying storage with concurrency check. + /// + /// The type of the arguments used to update the entity. + /// The entity to update. + /// The arguments that supplied updated values. + /// The retry delegate for preparing new update. + /// An optional merge callback if default merge is not sufficient. + /// The entity with the newly inserted id if insert was performed, the same entity otherwise. + public async Task Update(TEntity? entity, TArgs args, Func>? concurrencyRetrying, Func>? merging) + where TArgs : IDto + { + if (entity is null) + return entity; + + DBConcurrencyException? lastException = null; + /* + * Merge the updating entity with the supplied arguments. If callback is provided it is used instead of the default merge. + */ + var currentEntity = merging is null ? entity.Merge(args, entity.State) : await merging(entity); + /* + * There will be 3 retries. If none is succedded an exception will be thrown. + */ + for (var i = 0; i < 3; i++) + { + try + { + /* + * Perform the update. Note that provider should check for concurrency only if + * the entity is updating. Concurrency is not used for inserting and deleting operations. + */ + await Update(currentEntity); + /* + * Provider will merge the updating entity with a new id if the operation is Insert. For updating and + * deleting operations the same entity is returned. + */ + return currentEntity; + } + catch (DBConcurrencyException ex) + { + /* + * Concurrency exception occurred. If the callback is not passed we return immediatelly. + */ + if (concurrencyRetrying is null) + throw; + + lastException = ex; + /* + * Wait a small amount of time if the system is currently under heavy load to increase the probabillity + * of successful update. + */ + await Task.Delay(i * i * 50); + /* + * We must perform validation again since the state of the entities has possibly changed. Note that + * only middleware validation is performed not the argument (attribute based). + */ + if (await Middleware.Query>() is ImmutableList> items) + { + foreach (var item in items) + await item.Validate(args); + } + /* + * If validation succedded invoke callback which usually refreshes the cache which causes the entity to be + * reloaded from the data source. + */ + currentEntity = await concurrencyRetrying(); + /* + * Entity must be supplied and the merge is performed again. + */ + if (currentEntity is not null) + currentEntity = merging is null ? currentEntity.Merge(args, entity.State) : await merging(currentEntity); + else + throw new NullReferenceException(nameof(entity)); + } + } + /* + * This is not good. We couldn't update the entity after 3 retries. The system is most probably either under heavy load and + * the entity is updating very frequently. + */ + if (lastException is not null) + throw lastException; + + return default; + } + /// + /// Executes storage operation agains one or mode storages. + /// + /// The arguments containing data about operation to be performed. + /// The number of records affected in the physical storage. + /// If concurrency is supported on the operation and + /// no records have been affected and the actual operation is UPDATE this exception is thrown. + public async Task Execute(StorageContextArgs args) + { + await using var writer = await OpenWriter(args); + /* + * Execute operation against the storage. This method should return the number of records affected. + */ + var recordsAffected = await writer.Execute(); + /* + * It is not necessary that concurrency is actualy considered. Concurrency should be disabled + * if the operation is not UPDATE or the entity does not supports it (does not have an Etag or similar property). + */ + if (recordsAffected == 0 && args.Operation.Concurrency == DataConcurrencyMode.Enabled) + throw new DBConcurrencyException($"{SR.ErrDataConcurrency} ({typeof(Entity).Name})"); + /* + * Bind storage parameters with operation parameters. + */ + ReturnValueBinder.Bind(writer, args.Operation); + + return recordsAffected; + } + /// + /// Opens one or more readers for the specified entity. + /// + /// + /// If the entity does not support sharding, only one reader is returned. If arguments require more + /// than one shard to be read this method will return one for every shard. + /// + /// The arguments containing data about operation to be performed. + /// One or more . + private async Task>> OpenEntityReaders(StorageContextArgs args) + { + /* + * Connection middleware will return one connection for every shard. If sharding is not supported only + * one connection will be returned. + */ + var connections = await Connections.Open(args); + var result = new List>(); + + foreach (var connection in connections) + result.Add(OpenReader(args.Operation, connection)); + + return result.ToImmutableList(); + } + /// + /// Opens one or more readers for the specified entity. + /// + /// The arguments containing data about operation to be performed. + /// One or more . + public async Task> OpenReaders(StorageContextArgs args) + { + /* + * Connection middleware will return one connection for every shard. If sharding is not supported only + * one connection will be returned. + */ + var connections = await Connections.Open(args); + var result = new List(); + + foreach (var connection in connections) + { + /* + * Temporarly create a full database reader. We won't actually need it but it is + * the only way to get to the actual IDataReader. + */ + await using var r = StorageMiddleware.OpenReader(args.Operation, connection); + /* + * Now open reader and add it to the result. + */ + result.Add(await r.OpenReader()); + } + + return result.ToImmutableList(); + } + /// + /// Opens the on the underlying connection. + /// + /// The operation to be performed on the data reader. + /// The connection to be used when opening the reader. + /// The . + private IStorageReader OpenReader(IStorageOperation operation, IStorageConnection connection) + { + return StorageMiddleware.OpenReader(operation, connection); + } + /// + /// Opens the on the urderlying connection. + /// + /// The arguments containing operation to be performed. + /// The . + private async Task OpenWriter(StorageContextArgs args) + { + var connections = await Connections.Open(args); + /* + * Only one connection should be returned when performing transactions + * on the single entity. + */ + if (connections.Count != 1) + throw new InvalidOperationException("Only one connection expected."); + + return OpenWriter(args.Operation, connections[0]); + } + /// + /// Opens the on the urderlying connection. + /// + /// The operation to be performed. + /// The connection to be used on the writer. + /// The . + private IStorageWriter OpenWriter(IStorageOperation operation, IStorageConnection connection) + { + /* + * Signal transaction orchestration that we are going to use transactions. + */ + Transactions.IsDirty = true; + + return StorageMiddleware.OpenWriter(operation, connection); + } + /// + /// Performs a query for the specified operation. + /// + /// Arguments containing data about operation to be performed. + /// A List of entities that were returned from the storage. + public async Task?> Query(StorageContextArgs args) + { + var readers = await OpenEntityReaders(args); + /* + * In a sharding model it is possible that more than one reader will be returned since + * data could reside in more than one shard, for example: + * we have a projects, each having its work items in it own shard. It's fine to query + * work items for the project since they are definitely in the same shard. But whyt about + * querying work items for the specific user. If the user has access to the more then one + * project it is very likely that work items are in more than one shard. + */ + if (readers.Count == 1) + { + var result = await readers[0].Query(); + + await readers[0].DisposeAsync(); + + return result; + } + else + { + /* + * It's a sharding scenario + */ + var results = new List(); + var tasks = new List(); + + foreach (var reader in readers) + { + tasks.Add(Task.Run(async () => + { + if (await reader.Query() is ImmutableList r && !r.IsEmpty) + { + lock (results) + results.AddRange(r); + } + })); + } + + await Task.WhenAll(tasks); + + /* + * Need to manually dispose all readers. + */ + foreach (var reader in readers) + await reader.DisposeAsync(); + + return results.ToImmutableList(); + } + } + /// + /// Performs a single entity select for the specified operation. + /// + /// Arguments containing data about operation to be performed. + /// An entity if found, null otherwise. + public async Task Select(StorageContextArgs args) + { + var readers = await OpenEntityReaders(args); + /* + * In a sharding model, it is possible that a middleware won't know + * exactly in which shard the record resides. This is not ideal but very much + * possible scenario. This is why we will perform a call on all available + * readers and then, if more then one record returned, selects only the first one. + * ------------------------------------------------------------------------- + * Q: should we throw an exception if more than one record is found? + * ------------------------------------------------------------------------- + */ + TEntity? result = default; + + if (readers.Count == 1) + result = await readers[0].Select(); + else + { + var results = new List(); + var tasks = new List(); + + foreach (var reader in readers) + { + tasks.Add(Task.Run(async () => + { + if (await reader.Select() is TEntity r) + { + lock (results) + results.Add(r); + } + })); + } + + await Task.WhenAll(tasks); + + if (results.Any()) + result = results[0]; + } + /* + * Need to manually dispose all readers. + */ + foreach (var reader in readers) + await reader.DisposeAsync(); + + return result; + } + /// + /// Resolved provider used based on the entity type. + /// + /// + /// + private async Task ResolveProvider() + { + /* + * We need to resolve provider based on an entity type. At least one + * provider must respond to the entity type. On the other hand, only + * one provider should handle entity type. This means sharding is not + * supported on nodes with different connection types. + */ + var middlewares = await Middleware.Query(); + + if (!middlewares.Any()) + throw new NullReferenceException(nameof(IStorageMiddleware)); + + foreach (var middleware in middlewares) + { + /* + * The first middleware supporting the entity wins. + */ + if (middleware.SupportsEntity(ElementType)) + { + _provider = middleware; + + break; + } + } + + if (_provider is null) + throw new NullReferenceException($"{nameof(IStorageMiddleware)} -> {ElementType.Name}"); + } + + public async Task Initialize() + { + await ResolveProvider(); + } +} diff --git a/Connected.Data/Storage/IConnectionProvider.cs b/Connected.Data/Storage/IConnectionProvider.cs new file mode 100644 index 0000000..eaa129b --- /dev/null +++ b/Connected.Data/Storage/IConnectionProvider.cs @@ -0,0 +1,27 @@ +using Connected.Data.Sharding; +using Connected.Entities.Storage; +using System.Collections.Immutable; + +namespace Connected.Data.Storage; +/// +/// This middleware provides one or more connection for the specified arguments. +/// +/// +/// If entity supports sharding (provided by ) it is possible that +/// more than one connection will be returned. For the transactions only one connection is tipically returned +/// since only one entity at a time is usually performed. For query operations the scenario could be more complex +/// because data could reside in more than one shard. In that case one connection for each shhard will be returned. +/// +public interface IConnectionProvider +{ + /// + /// Opens one or more connections for the specified arguments. + /// + /// The type of the entity on which storage operations will be performed. + /// The arguments describing what operation is about to be performed. + /// One or more storage connections. One connection if entity does not supports sharding, more if + /// it supports it and the operation requires data from more than one shard. + Task> Open(StorageContextArgs args); + + StorageConnectionMode Mode { get; set; } +} diff --git a/Connected.Data/Storage/IStorageCommand.cs b/Connected.Data/Storage/IStorageCommand.cs new file mode 100644 index 0000000..c3c088e --- /dev/null +++ b/Connected.Data/Storage/IStorageCommand.cs @@ -0,0 +1,10 @@ +using Connected.Entities.Storage; + +namespace Connected.Data.Storage +{ + public interface IStorageCommand : IDisposable, IAsyncDisposable + { + IStorageOperation Operation { get; } + IStorageConnection? Connection { get; } + } +} diff --git a/Connected.Data/Storage/IStorageConnection.cs b/Connected.Data/Storage/IStorageConnection.cs new file mode 100644 index 0000000..728d9e7 --- /dev/null +++ b/Connected.Data/Storage/IStorageConnection.cs @@ -0,0 +1,25 @@ +using Connected.Entities.Storage; +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Data.Storage +{ + public interface IStorageConnection : IMiddleware, IAsyncDisposable, IDisposable + { + StorageConnectionMode Behavior { get; } + string ConnectionString { get; } + + Task Initialize(StorageConnectionArgs args); + Task Commit(); + Task Rollback(); + Task Close(); + + Task Execute(IStorageCommand command); + + Task> Query(IStorageCommand command); + + Task Select(IStorageCommand command); + + Task OpenReader(IStorageCommand command); + } +} diff --git a/Connected.Data/Storage/IStorageMiddleware.cs b/Connected.Data/Storage/IStorageMiddleware.cs new file mode 100644 index 0000000..a66a4d5 --- /dev/null +++ b/Connected.Data/Storage/IStorageMiddleware.cs @@ -0,0 +1,14 @@ +using Connected.Data.Storage; +using Connected.Entities; +using Connected.Entities.Storage; + +namespace Connected.Data; +public interface IStorageMiddleware : IQueryProvider, IMiddleware +{ + bool SupportsEntity(Type entityType); + IStorageOperation CreateOperation(TEntity entity) + where TEntity : IEntity; + + IStorageReader OpenReader(IStorageOperation operation, IStorageConnection connection); + IStorageWriter OpenWriter(IStorageOperation operation, IStorageConnection connection); +} diff --git a/Connected.Data/Storage/IStorageReader.cs b/Connected.Data/Storage/IStorageReader.cs new file mode 100644 index 0000000..9530790 --- /dev/null +++ b/Connected.Data/Storage/IStorageReader.cs @@ -0,0 +1,11 @@ +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Data.Storage; + +public interface IStorageReader : IStorageCommand +{ + Task> Query(); + Task Select(); + Task OpenReader(); +} diff --git a/Connected.Data/Storage/IStorageWriter.cs b/Connected.Data/Storage/IStorageWriter.cs new file mode 100644 index 0000000..728b6e0 --- /dev/null +++ b/Connected.Data/Storage/IStorageWriter.cs @@ -0,0 +1,6 @@ +namespace Connected.Data.Storage; + +public interface IStorageWriter : IStorageCommand +{ + Task Execute(); +} diff --git a/Connected.Data/Storage/ReturnValueBinder.cs b/Connected.Data/Storage/ReturnValueBinder.cs new file mode 100644 index 0000000..888ae10 --- /dev/null +++ b/Connected.Data/Storage/ReturnValueBinder.cs @@ -0,0 +1,110 @@ +using System.Data; +using System.Reflection; +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using Connected.Interop; + +namespace Connected.Data.Storage +{ + internal static class ReturnValueBinder + { + public static void Bind(IStorageWriter w, IStorageOperation operation) + { + List properties = null; + + if (w.Operation.Parameters is null) + return; + + foreach (var parameter in w.Operation.Parameters) + { + if (parameter.Direction != ParameterDirection.ReturnValue && parameter.Direction != ParameterDirection.Output) + continue; + + if (parameter.Value == DBNull.Value) + continue; + + if (properties is null) + { + properties = new List(); + + var all = operation.GetType().GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + foreach (var prop in all) + { + if (prop.FindAttribute() is not null) + properties.Add(prop); + } + } + + PropertyInfo property = null; + + if (property is null) + { + foreach (var prop in properties) + { + if (string.Equals(prop.Name, parameter.Name, StringComparison.Ordinal)) + { + property = prop; + break; + } + } + } + + if (property is null) + { + foreach (var prop in properties) + { + if (string.Equals(prop.Name, parameter.Name, StringComparison.OrdinalIgnoreCase)) + { + property = prop; + break; + } + } + } + + if (property is null) + { + var candidates = new List + { + parameter.Name.Replace("@", string.Empty) + }; + + foreach (var prop in properties) + { + foreach (var candidate in candidates) + { + if (string.Equals(prop.Name, candidate, StringComparison.OrdinalIgnoreCase)) + { + property = prop; + break; + } + } + + if (property is not null) + break; + } + } + + if (property is not null) + { + var existingValue = property.GetValue(operation); + var overwriteAtt = property.FindAttribute(); + + switch (overwriteAtt.ValueBehavior) + { + case PropertyValueBehavior.OverwriteDefault: + var defaultValue = property.PropertyType.GetDefault(); + + if (Equals(existingValue, defaultValue)) + property.SetValue(operation, parameter.Value); + break; + case PropertyValueBehavior.AlwaysOverwrite: + property.SetValue(operation, parameter.Value); + break; + } + } + } + } + + } +} diff --git a/Connected.Data/Storage/StorageConnectionArgs.cs b/Connected.Data/Storage/StorageConnectionArgs.cs new file mode 100644 index 0000000..d2be757 --- /dev/null +++ b/Connected.Data/Storage/StorageConnectionArgs.cs @@ -0,0 +1,14 @@ +using Connected.Entities.Storage; + +namespace Connected.Data.Storage; +public sealed class StorageConnectionArgs : EventArgs +{ + public StorageConnectionArgs(string connectionString, StorageConnectionMode behavior) + { + ConnectionString = connectionString; + Behavior = behavior; + } + + public string ConnectionString { get; } + public StorageConnectionMode Behavior { get; } +} diff --git a/Connected.Data/Storage/StorageProvider.cs b/Connected.Data/Storage/StorageProvider.cs new file mode 100644 index 0000000..afd7ca8 --- /dev/null +++ b/Connected.Data/Storage/StorageProvider.cs @@ -0,0 +1,36 @@ +using Connected.Data.DataProtection; +using Connected.Entities; +using Connected.Entities.Storage; +using Connected.Middleware; +using Connected.ServiceModel.Transactions; + +namespace Connected.Data.Storage; + +internal sealed class StorageProvider : IStorageProvider +{ + public StorageProvider(IEntityProtectionService dataProtection, IConnectionProvider connections, ITransactionContext transactions, IMiddlewareService middleware) + { + DataProtection = dataProtection; + Connections = connections; + Transactions = transactions; + Middleware = middleware; + } + private IEntityProtectionService DataProtection { get; } + private IConnectionProvider Connections { get; } + private ITransactionContext Transactions { get; } + private IMiddlewareService Middleware { get; } + /// + /// Opens for reading and writing entities. + /// + /// Type type of the entity to be used. + /// The on which LINQ queries and updates can be performed. + public IStorage Open() + where TEntity : IEntity + { + var result = new EntityStorage(DataProtection, Middleware, Connections, Transactions); + + AsyncUtils.RunSync(result.Initialize); + + return result; + } +} diff --git a/Connected.Data/Update/AggregatedCommandBuilder.cs b/Connected.Data/Update/AggregatedCommandBuilder.cs new file mode 100644 index 0000000..2968580 --- /dev/null +++ b/Connected.Data/Update/AggregatedCommandBuilder.cs @@ -0,0 +1,51 @@ +using System.Collections.Immutable; +using Connected.Entities; +using Connected.Entities.Storage; + +namespace Connected.Data.Update; + +internal class AggregatedCommandBuilder +{ + public StorageOperation? Build(TEntity entity) + { + if (entity is not IEntity ie) + throw new ArgumentException(nameof(entity)); + + switch (ie.State) + { + case State.New: + return BuildInsert(ie); + case State.Default: + return BuildUpdate(ie); + case State.Deleted: + return BuildDelete(ie); + default: + throw new NotSupportedException(); + } + } + + public List Build(ImmutableArray entities) + { + var result = new List(); + + foreach (var entity in entities) + result.Add(Build(entity)); + + return result; + } + + private StorageOperation? BuildInsert(IEntity entity) + { + return new InsertCommandBuilder().Build(entity); + } + + private StorageOperation? BuildUpdate(IEntity entity) + { + return new UpdateCommandBuilder().Build(entity); + } + + private StorageOperation? BuildDelete(IEntity entity) + { + return new DeleteCommandBuilder().Build(entity); + } +} diff --git a/Connected.Data/Update/CommandBuilder.cs b/Connected.Data/Update/CommandBuilder.cs new file mode 100644 index 0000000..1691613 --- /dev/null +++ b/Connected.Data/Update/CommandBuilder.cs @@ -0,0 +1,216 @@ +using Connected.Entities; +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using Connected.Interop; +using System.Data; +using System.Reflection; +using System.Text; + +namespace Connected.Data.Update; + +internal abstract class CommandBuilder +{ + private readonly List _parameters; + private readonly List _whereProperties; + private List _properties; + + protected CommandBuilder() + { + _parameters = new List(); + _whereProperties = new List(); + + Text = new StringBuilder(); + } + + public StorageOperation? Build(IEntity entity) + { + Entity = entity; + + if (TryGetExisting(out StorageOperation? existing)) + { + /* + * We need to rebuild an instance since StorageOperation + * is immutable + */ + var result = new StorageOperation + { + CommandText = existing.CommandText, + CommandTimeout = existing.CommandTimeout, + CommandType = existing.CommandType, + Concurrency = existing.Concurrency + }; + + if (result.Parameters is null) + return result; + + foreach (var parameter in result.Parameters) + { + if (parameter.Direction == ParameterDirection.Input) + { + if (ResolveProperty(parameter.Name) is PropertyInfo property) + { + result.AddParameter(new StorageParameter + { + Value = GetValue(property), + Name = parameter.Name, + Type = parameter.Type, + Direction = parameter.Direction + }); + } + } + } + + return result; + } + + Schema = Entity.GetSchemaAttribute(); + + return OnBuild(); + } + + protected List Properties => _properties ??= GetProperties(); + + protected List WhereProperties => _whereProperties; + + protected string CommandText => Text.ToString(); + + protected IEntity Entity { get; private set; } + + protected SchemaAttribute Schema { get; private set; } + + private StringBuilder Text { get; set; } + + protected abstract StorageOperation OnBuild(); + + protected abstract bool TryGetExisting(out StorageOperation? result); + + protected List Parameters => _parameters; + + protected void Write(string text) + { + Text.Append(text); + } + + protected void Write(char text) + { + Text.Append(text); + } + + protected void WriteLine(string text) + { + Text.AppendLine(text); + } + + protected void Trim() + { + for (var i = Text.Length - 1; i >= 0; i--) + { + if (!Text[i].Equals(',') && !Text[i].Equals('\n') && !Text[i].Equals('\r') && !Text[i].Equals(' ')) + break; + + if (i < Text.Length) + Text.Length = i; + } + } + + protected virtual List GetProperties() + { + return Interop.Properties.GetImplementedProperties(Entity); + } + + protected static bool IsVersion(PropertyInfo property) + { + return property.GetCustomAttribute() is not null; + } + + protected static string ColumnName(PropertyInfo property) + { + var dataMember = property.FindAttribute(); + + return dataMember is null || string.IsNullOrEmpty(dataMember.Member) ? property.Name.ToCamelCase() : dataMember.Member; + } + + protected static DbType ResolveDbType(PropertyInfo property) + { + if (IsVersion(property)) + return DbType.Binary; + + return property.PropertyType.ToDbType(); + } + protected object? GetValue(PropertyInfo property) + { + if (IsNull(property)) + return "NULL"; + + if (IsVersion(property)) + return (byte[])EntityVersion.Parse(property.GetValue(Entity)); + + return GetValue(property.GetValue(Entity), property.PropertyType.ToDbType()); + } + + private static object? GetValue(object value, DbType dbType) + { + switch (dbType) + { + case DbType.Binary: + if (value is byte[] bytes) + return Convert.ToBase64String(bytes); + else + return Convert.ToBase64String(Encoding.UTF8.GetBytes(value.ToString())); + default: + return value; + } + } + + private bool IsNull(PropertyInfo property) + { + var result = property.GetValue(Entity); + + if (result is null) + return true; + + if (property.GetCustomAttribute() is null) + return false; + + var def = Types.GetDefault(property.PropertyType); + + return TypeComparer.Compare(result, def); + } + + protected StorageParameter CreateParameter(PropertyInfo property) + { + return CreateParameter(property, ParameterDirection.Input); + } + + protected StorageParameter CreateParameter(PropertyInfo property, ParameterDirection direction) + { + var columnName = ColumnName(property); + var parameterName = $"@{columnName}"; + + var parameter = new StorageParameter + { + Direction = direction, + Name = parameterName, + Type = ResolveDbType(property), + Value = GetValue(property) + }; + + Parameters.Add(parameter); + + return parameter; + } + + private PropertyInfo ResolveProperty(string parameterName) + { + var propertyName = parameterName[1..]; + var flags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic; + + if (Entity.GetType().GetProperty(propertyName.ToPascalCase(), flags) is PropertyInfo property) + return property; + + if (Entity.GetType().GetProperty(propertyName, flags) is PropertyInfo raw) + return raw; + + return null; + } +} diff --git a/Connected.Data/Update/DeleteCommandBuilder.cs b/Connected.Data/Update/DeleteCommandBuilder.cs new file mode 100644 index 0000000..7bcb46b --- /dev/null +++ b/Connected.Data/Update/DeleteCommandBuilder.cs @@ -0,0 +1,57 @@ +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using System.Collections.Concurrent; +using System.Reflection; + +namespace Connected.Data.Update; + +internal class DeleteCommandBuilder : CommandBuilder +{ + private static readonly ConcurrentDictionary _cache; + + static DeleteCommandBuilder() + { + _cache = new(); + } + + private static ConcurrentDictionary Cache => _cache; + + protected override StorageOperation OnBuild() + { + WriteLine($"DELETE [{Schema.Schema}].[{Schema.Name}] ("); + WriteWhere(); + + var result = new StorageOperation { CommandText = CommandText }; + + foreach (var parameter in Parameters) + result.AddParameter(parameter); + + Cache.TryAdd(Entity.GetType().FullName, result); + + return result; + } + + private void WriteWhere() + { + Write("WHERE "); + + foreach (var property in Properties) + { + if (property.GetCustomAttribute() is not null) + { + var columnName = ColumnName(property); + + CreateParameter(property); + + Write($"{ColumnName} = @{ColumnName}"); + } + } + + Write(";"); + } + + protected override bool TryGetExisting(out StorageOperation? result) + { + return Cache.TryGetValue(Entity.GetType().FullName, out result); + } +} diff --git a/Connected.Data/Update/InsertCommandBuilder.cs b/Connected.Data/Update/InsertCommandBuilder.cs new file mode 100644 index 0000000..d502e3c --- /dev/null +++ b/Connected.Data/Update/InsertCommandBuilder.cs @@ -0,0 +1,86 @@ +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using System.Collections.Concurrent; +using System.Data; +using System.Reflection; + +namespace Connected.Data.Update; + +internal class InsertCommandBuilder : CommandBuilder +{ + private static readonly ConcurrentDictionary _cache; + + static InsertCommandBuilder() + { + _cache = new(); + } + + private static ConcurrentDictionary Cache => _cache; + private IStorageParameter PrimaryKeyParameter { get; set; } + + protected override StorageOperation OnBuild() + { + WriteLine($"INSERT [{Schema.Schema}].[{Schema.Name}] ("); + + WriteColumns(); + WriteLine(")"); + Write("VALUES ("); + WriteValues(); + WriteLine(");"); + WriteOutput(); + + var result = new StorageOperation { CommandText = CommandText }; + + foreach (var parameter in Parameters) + result.AddParameter(parameter); + + Cache.TryAdd(Entity.GetType().FullName, result); + + return result; + } + + private void WriteOutput() + { + WriteLine($"SET {PrimaryKeyParameter.Name} = scope_identity();"); + } + + private void WriteColumns() + { + foreach (var property in Properties) + { + if (property.GetCustomAttribute() is not null || IsVersion(property)) + { + if (IsVersion(property)) + continue; + + PrimaryKeyParameter = CreateParameter(property, ParameterDirection.Output); + + continue; + } + + CreateParameter(property); + + Write($"{ColumnName(property)}, "); + } + + Trim(); + } + + private void WriteValues() + { + foreach (var property in Properties) + { + if (property.GetCustomAttribute() is not null || IsVersion(property)) + continue; + + Write($"@{ColumnName(property)}, "); + } + + Trim(); + } + + protected override bool TryGetExisting(out StorageOperation? result) + { + return Cache.TryGetValue(Entity.GetType().FullName, out result); + } +} diff --git a/Connected.Data/Update/UpdateCommandBuilder.cs b/Connected.Data/Update/UpdateCommandBuilder.cs new file mode 100644 index 0000000..29de68a --- /dev/null +++ b/Connected.Data/Update/UpdateCommandBuilder.cs @@ -0,0 +1,81 @@ +using Connected.Entities.Annotations; +using Connected.Entities.Storage; +using System.Collections.Concurrent; +using System.Reflection; + +namespace Connected.Data.Update; + +internal sealed class UpdateCommandBuilder : CommandBuilder +{ + private static readonly ConcurrentDictionary _cache; + + static UpdateCommandBuilder() + { + _cache = new(); + } + + private static ConcurrentDictionary Cache => _cache; + private bool SupportsConcurrency { get; set; } + protected override bool TryGetExisting(out StorageOperation? result) + { + return Cache.TryGetValue(Entity.GetType().FullName, out result); + } + + protected override StorageOperation OnBuild() + { + WriteLine($"UPDATE [{Schema.Schema}].[{Schema.Name}] SET"); + + WriteAssignments(); + WriteWhere(); + + Trim(); + Write(';'); + + var result = new StorageOperation { CommandText = CommandText, Concurrency = SupportsConcurrency ? DataConcurrencyMode.Enabled : DataConcurrencyMode.Disabled }; + + foreach (var parameter in Parameters) + result.AddParameter(parameter); + + Cache.TryAdd(Entity.GetType().FullName, result); + + return result; + } + + private void WriteAssignments() + { + foreach (var property in Properties) + { + if (property.GetCustomAttribute() is not null || IsVersion(property)) + { + if (IsVersion(property)) + SupportsConcurrency = true; + + WhereProperties.Add(property); + + continue; + } + + var parameter = CreateParameter(property); + + WriteLine($"{ColumnName(property)} = {parameter.Name},"); + } + + Trim(); + } + + private void WriteWhere() + { + WriteLine(string.Empty); + + for (var i = 0; i < WhereProperties.Count; i++) + { + var property = WhereProperties[i]; + var parameter = CreateParameter(property); + + if (i == 0) + WriteLine($" WHERE {ColumnName(property)} = {parameter.Name}"); + else + WriteLine($" AND {ColumnName(property)} = {parameter.Name}"); + } + } +} diff --git a/Connected.Entities/Annotations/AssociationAttribute.cs b/Connected.Entities/Annotations/AssociationAttribute.cs new file mode 100644 index 0000000..a8638c1 --- /dev/null +++ b/Connected.Entities/Annotations/AssociationAttribute.cs @@ -0,0 +1,11 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class AssociationAttribute : MemberAttribute + { + public string? KeyMembers { get; set; } + public string? RelatedEntityId { get; set; } + public string? RelatedKeyMembers { get; set; } + public bool IsForeignKey { get; set; } + } +} \ No newline at end of file diff --git a/Connected.Entities/Annotations/BinaryAttribute.cs b/Connected.Entities/Annotations/BinaryAttribute.cs new file mode 100644 index 0000000..97133ec --- /dev/null +++ b/Connected.Entities/Annotations/BinaryAttribute.cs @@ -0,0 +1,15 @@ +namespace Connected.Entities.Annotations +{ + public enum BinaryKind + { + VarBinary = 0, + Binary = 1 + } + + [AttributeUsage(AttributeTargets.Property)] + + public sealed class BinaryAttribute : Attribute + { + public BinaryKind Kind { get; set; } + } +} diff --git a/Connected.Entities/Annotations/DateAttribute.cs b/Connected.Entities/Annotations/DateAttribute.cs new file mode 100644 index 0000000..bbb8045 --- /dev/null +++ b/Connected.Entities/Annotations/DateAttribute.cs @@ -0,0 +1,18 @@ +namespace Connected.Entities.Annotations +{ + public enum DateKind + { + NotSet = 0, + Date = 1, + DateTime = 2, + DateTime2 = 3, + SmallDateTime = 4, + Time = 5 + } + [AttributeUsage(AttributeTargets.Property)] + public sealed class DateAttribute : Attribute + { + public DateKind Kind { get; set; } = DateKind.DateTime; + public int Precision { get; set; } + } +} diff --git a/Connected.Entities/Annotations/DefaultAttribute.cs b/Connected.Entities/Annotations/DefaultAttribute.cs new file mode 100644 index 0000000..6e481fe --- /dev/null +++ b/Connected.Entities/Annotations/DefaultAttribute.cs @@ -0,0 +1,12 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class DefaultAttribute : Attribute + { + public DefaultAttribute(object value) + { + Value = value; + } + public object? Value { get; } + } +} diff --git a/Connected.Entities/Annotations/ETagAttribute.cs b/Connected.Entities/Annotations/ETagAttribute.cs new file mode 100644 index 0000000..3b037d7 --- /dev/null +++ b/Connected.Entities/Annotations/ETagAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class ETagAttribute : Attribute + { + } +} diff --git a/Connected.Entities/Annotations/EntityAttribute.cs b/Connected.Entities/Annotations/EntityAttribute.cs new file mode 100644 index 0000000..54aa60c --- /dev/null +++ b/Connected.Entities/Annotations/EntityAttribute.cs @@ -0,0 +1,10 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = false)] + + public sealed class EntityAttribute : MappingAttribute + { + public string? Id { get; set; } + public Type? RuntimeType { get; set; } + } +} diff --git a/Connected.Entities/Annotations/IndexAttribute.cs b/Connected.Entities/Annotations/IndexAttribute.cs new file mode 100644 index 0000000..2c6c813 --- /dev/null +++ b/Connected.Entities/Annotations/IndexAttribute.cs @@ -0,0 +1,8 @@ +namespace Connected.Entities.Annotations +{ + public sealed class IndexAttribute : Attribute + { + public bool Unique { get; set; } + public string? Name { get; set; } + } +} diff --git a/Connected.Entities/Annotations/LengthAttribute.cs b/Connected.Entities/Annotations/LengthAttribute.cs new file mode 100644 index 0000000..a13457c --- /dev/null +++ b/Connected.Entities/Annotations/LengthAttribute.cs @@ -0,0 +1,12 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class LengthAttribute : Attribute + { + public LengthAttribute(int value) + { + Value = value; + } + public int Value { get; } + } +} diff --git a/Connected.Entities/Annotations/MappingAttribute.cs b/Connected.Entities/Annotations/MappingAttribute.cs new file mode 100644 index 0000000..303ddd2 --- /dev/null +++ b/Connected.Entities/Annotations/MappingAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Property | AttributeTargets.Field)] + public abstract class MappingAttribute : Attribute + { + } +} diff --git a/Connected.Entities/Annotations/MemberAttribute.cs b/Connected.Entities/Annotations/MemberAttribute.cs new file mode 100644 index 0000000..190572f --- /dev/null +++ b/Connected.Entities/Annotations/MemberAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities.Annotations +{ + public class MemberAttribute : MappingAttribute + { + public string? Member { get; set; } + } +} diff --git a/Connected.Entities/Annotations/NullableAttribute.cs b/Connected.Entities/Annotations/NullableAttribute.cs new file mode 100644 index 0000000..d9a8e9b --- /dev/null +++ b/Connected.Entities/Annotations/NullableAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities.Annotations; + +[AttributeUsage(AttributeTargets.Property)] +public sealed class NullableAttribute : Attribute +{ + public bool IsNullable { get; set; } = true; +} diff --git a/Connected.Entities/Annotations/NumericAttribute.cs b/Connected.Entities/Annotations/NumericAttribute.cs new file mode 100644 index 0000000..5967660 --- /dev/null +++ b/Connected.Entities/Annotations/NumericAttribute.cs @@ -0,0 +1,9 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class NumericAttribute : Attribute + { + public int Percision { get; set; } + public int Scale { get; set; } + } +} diff --git a/Connected.Entities/Annotations/PersistenceAttribute.cs b/Connected.Entities/Annotations/PersistenceAttribute.cs new file mode 100644 index 0000000..2bb0ea8 --- /dev/null +++ b/Connected.Entities/Annotations/PersistenceAttribute.cs @@ -0,0 +1,22 @@ +namespace Connected.Entities.Annotations +{ + [Flags] + public enum ColumnPersistence + { + InMemory = 0, + Read = 1, + Write = 2, + ReadWrite = 3 + } + + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Class)] + public sealed class PersistenceAttribute : Attribute + { + public ColumnPersistence Persistence { get; set; } + + public bool IsReadOnly => (Persistence & ColumnPersistence.Read) == ColumnPersistence.Read; + public bool IsWriteOnly => (Persistence & ColumnPersistence.Write) == ColumnPersistence.Write; + public bool IsReadWrite => (Persistence & ColumnPersistence.ReadWrite) == ColumnPersistence.ReadWrite; + public bool IsVirtual => Persistence == ColumnPersistence.InMemory; + } +} diff --git a/Connected.Entities/Annotations/PrimaryKeyAttribute.cs b/Connected.Entities/Annotations/PrimaryKeyAttribute.cs new file mode 100644 index 0000000..e6d94fe --- /dev/null +++ b/Connected.Entities/Annotations/PrimaryKeyAttribute.cs @@ -0,0 +1,8 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class PrimaryKeyAttribute : Attribute + { + public bool Identity { get; set; } = true; + } +} diff --git a/Connected.Entities/Annotations/ReturnValueAttribute.cs b/Connected.Entities/Annotations/ReturnValueAttribute.cs new file mode 100644 index 0000000..35d7a98 --- /dev/null +++ b/Connected.Entities/Annotations/ReturnValueAttribute.cs @@ -0,0 +1,13 @@ +namespace Connected.Entities.Annotations +{ + public enum PropertyValueBehavior + { + OverwriteDefault = 1, + AlwaysOverwrite = 2, + } + [AttributeUsage(AttributeTargets.Property)] + public sealed class ReturnValueAttribute : Attribute + { + public PropertyValueBehavior ValueBehavior { get; set; } = PropertyValueBehavior.OverwriteDefault; + } +} diff --git a/Connected.Entities/Annotations/SchemaAttribute.cs b/Connected.Entities/Annotations/SchemaAttribute.cs new file mode 100644 index 0000000..0b5fce8 --- /dev/null +++ b/Connected.Entities/Annotations/SchemaAttribute.cs @@ -0,0 +1,17 @@ +namespace Connected.Entities.Annotations +{ + public abstract class SchemaAttribute : MappingAttribute + { + public const string DefaultSchema = "dbo"; + public const string SchemaTypeTable = "Table"; + /* + * sys schema is reserved for system views and tables by sql. + */ + public const string SysSchema = "sxs"; + public const string TypesSchema = "typ"; + public string? Id { get; set; } + public string? Name { get; set; } + public string? Schema { get; set; } = DefaultSchema; + + } +} diff --git a/Connected.Entities/Annotations/StringAttribute.cs b/Connected.Entities/Annotations/StringAttribute.cs new file mode 100644 index 0000000..76e006d --- /dev/null +++ b/Connected.Entities/Annotations/StringAttribute.cs @@ -0,0 +1,16 @@ +namespace Connected.Entities.Annotations +{ + public enum StringKind + { + NVarChar = 0, + VarChar = 1, + Char = 2, + NChar = 3 + } + + [AttributeUsage(AttributeTargets.Property)] + public sealed class StringAttribute : Attribute + { + public StringKind Kind { get; set; } + } +} diff --git a/Connected.Entities/Annotations/TableAttribute.cs b/Connected.Entities/Annotations/TableAttribute.cs new file mode 100644 index 0000000..4202be0 --- /dev/null +++ b/Connected.Entities/Annotations/TableAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface | AttributeTargets.Method, AllowMultiple = false)] + public class TableAttribute : SchemaAttribute + { + } +} diff --git a/Connected.Entities/Annotations/TableExtensionAttribute.cs b/Connected.Entities/Annotations/TableExtensionAttribute.cs new file mode 100644 index 0000000..7282219 --- /dev/null +++ b/Connected.Entities/Annotations/TableExtensionAttribute.cs @@ -0,0 +1,10 @@ +namespace Connected.Entities.Annotations +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = true)] + public sealed class TableExtensionAttribute : SchemaAttribute + { + public string? KeyColumns { get; set; } + public string? RelatedTableId { get; set; } + public string? RelatedKeyColumns { get; set; } + } +} diff --git a/Connected.Entities/Caching/EntityCacheClient.cs b/Connected.Entities/Caching/EntityCacheClient.cs new file mode 100644 index 0000000..78183f3 --- /dev/null +++ b/Connected.Entities/Caching/EntityCacheClient.cs @@ -0,0 +1,127 @@ +using Connected.Caching; +using Connected.Data; +using Connected.Entities.Concurrency; +using Connected.Entities.Storage; +using Connected.Interop; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using System.Collections.Immutable; + +namespace Connected.Entities.Caching; + +public abstract class EntityCacheClient : StatefulCacheClient, IEntityCacheClient + where TEntity : class, IPrimaryKey, IEntity + where TPrimaryKey : notnull +{ + protected EntityCacheClient(IEntityCacheContext context, string key) : base(context.Cache, key) + { + Context = context; + } + + private IEntityCacheContext Context { get; } + + protected override sealed async Task OnInitializing() + { + using var ctx = Context.ContextProvider.Create(); + + var transaction = ctx.GetService(); + + try + { + if (await OnInitializing(ctx) is ImmutableList ds) + { + foreach (var r in ds) + Set(r.Id, r, TimeSpan.Zero); + } + + if (transaction is not null) + await transaction.Commit(); + } + catch + { + if (transaction is not null) + await transaction.Rollback(); + + throw; + } + } + + protected virtual async Task?> OnInitializing(IContext context) + { + if (context.GetService() is not IStorageProvider db) + return default; + + return await (from dc in db.Open() + select dc).AsEntities(); + } + + protected override async Task OnInvalidate(TPrimaryKey id) + { + using var ctx = Context.ContextProvider.Create(); + var transaction = ctx.GetService(); + + try + { + if (OnInvalidating(ctx, id) is TEntity entity && entity is IPrimaryKey pk) + Set(pk.Id, entity, TimeSpan.Zero); + + if (transaction is not null) + await transaction.Commit(); + } + catch + { + if (transaction is not null) + await transaction.Rollback(); + + throw; + } + } + + protected virtual async Task OnInvalidating(IContext context, TPrimaryKey id) + { + if (context.GetService() is not IStorageProvider provider) + return default; + + return await (from dc in provider.Open() + where TypeComparer.Compare(dc.Id, id) + select dc).AsEntity(); + } + + async Task IEntityCacheClient.Refresh(TPrimaryKey id) + { + await Refresh(id); + } + + async Task IEntityCacheClient.Remove(TPrimaryKey id) + { + await Remove(id); + } + + protected override void Set(TPrimaryKey id, TEntity instance) + { + Set(id, instance, TimeSpan.Zero); + } + + protected override void Set(TPrimaryKey id, TEntity instance, TimeSpan duration) + { + if (instance is IConcurrentEntity concurrent) + { + if (Get(id) is TEntity existing && existing is IConcurrentEntity existingConcurrent) + { + lock (existingConcurrent) + { + if (existingConcurrent.Sync != concurrent.Sync) + throw new InvalidOperationException(SR.ErrConcurrent); + + concurrent.GetType().GetProperty(nameof(IConcurrentEntity.Sync))?.SetValue(concurrent, concurrent.Sync + 1); + + Set(id, instance, duration); + + return; + } + } + } + + base.Set(id, instance, duration); + } +} diff --git a/Connected.Entities/Caching/EntityCacheContext.cs b/Connected.Entities/Caching/EntityCacheContext.cs new file mode 100644 index 0000000..4a560bd --- /dev/null +++ b/Connected.Entities/Caching/EntityCacheContext.cs @@ -0,0 +1,17 @@ +using Connected.Caching; +using Connected.ServiceModel; + +namespace Connected.Entities.Caching; + +internal class EntityCacheContext : IEntityCacheContext +{ + public EntityCacheContext(ICachingService cachingService, IContextProvider contextProvider) + { + Cache = cachingService; + ContextProvider = contextProvider; + } + + public ICachingService Cache { get; } + + public IContextProvider ContextProvider { get; } +} diff --git a/Connected.Entities/Caching/IEntityCacheClient.cs b/Connected.Entities/Caching/IEntityCacheClient.cs new file mode 100644 index 0000000..756fc5a --- /dev/null +++ b/Connected.Entities/Caching/IEntityCacheClient.cs @@ -0,0 +1,9 @@ +using Connected.Caching; + +namespace Connected.Entities.Caching; + +public interface IEntityCacheClient : IStatefulCacheClient +{ + Task Refresh(TKey id); + Task Remove(TKey id); +} diff --git a/Connected.Entities/Caching/IEntityCacheContext.cs b/Connected.Entities/Caching/IEntityCacheContext.cs new file mode 100644 index 0000000..4f0c504 --- /dev/null +++ b/Connected.Entities/Caching/IEntityCacheContext.cs @@ -0,0 +1,10 @@ +using Connected.Caching; +using Connected.ServiceModel; + +namespace Connected.Entities.Caching; + +public interface IEntityCacheContext +{ + ICachingService Cache { get; } + IContextProvider ContextProvider { get; } +} diff --git a/Connected.Entities/Concurrency/ConcurrentEntity.cs b/Connected.Entities/Concurrency/ConcurrentEntity.cs new file mode 100644 index 0000000..ddac3c8 --- /dev/null +++ b/Connected.Entities/Concurrency/ConcurrentEntity.cs @@ -0,0 +1,17 @@ +using Connected.Entities.Annotations; +using Connected.Entities.Consistency; + +namespace Connected.Entities.Concurrency; + +public abstract record ConcurrentEntity : ConsistentEntity, IConcurrentEntity + where TPrimaryKey : notnull +{ + private int _sync = 0; + + [Persistence(Persistence = ColumnPersistence.InMemory)] + public int Sync + { + get => _sync; + set => Interlocked.Exchange(ref _sync, value); + } +} diff --git a/Connected.Entities/Concurrency/IConcurrentEntity.cs b/Connected.Entities/Concurrency/IConcurrentEntity.cs new file mode 100644 index 0000000..024d7fb --- /dev/null +++ b/Connected.Entities/Concurrency/IConcurrentEntity.cs @@ -0,0 +1,33 @@ +using Connected.Entities.Consistency; + +namespace Connected.Entities.Concurrency; + +/// +/// This entity is primarly used when cached in memory and access to the entity is frequent with +/// updates. The ensures that threads don't overwrite +/// values from other threads by using the property. +/// +/// +/// While ensures database consistency, +/// ensures application consistency. The is not thread safe but ensures +/// that any writes are rejected if the thread tries to write entity with invalid property value. +/// +/// +public interface IConcurrentEntity : IConsistentEntity + where TPrimaryKey : notnull +{ + /// + /// The synchronization value used when comparing if the write to the entity is made with the latest + /// version. Entities are immutable but they can be replaced in Cache with newer instances. The cache tipically + /// ensures that entities can't be overwritten with out of date values. + /// + /// + /// In Queue messages, all messages are stored in memory and multiple threads perform dequeue. Since dequeue means + /// overwriting some data and since the entities are immutable (except this entity and only the property) the operaton results with overwriting the entire + /// entity in cache. If two or more thready do it in the same time, they could accidentally overwrite values from + /// each other. The cache ensures that the current entity has the same value as the updating entity. If the write + /// occurred in the mean time it would result incrementing the value which would cause any subsequent + /// writes with the same originating entity would fail. + /// + int Sync { get; set; } +} diff --git a/Connected.Entities/Connected.Entities.csproj b/Connected.Entities/Connected.Entities.csproj new file mode 100644 index 0000000..38dee94 --- /dev/null +++ b/Connected.Entities/Connected.Entities.csproj @@ -0,0 +1,30 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Entities/Consistency/ConsistentEntity.cs b/Connected.Entities/Consistency/ConsistentEntity.cs new file mode 100644 index 0000000..75e1a4a --- /dev/null +++ b/Connected.Entities/Consistency/ConsistentEntity.cs @@ -0,0 +1,12 @@ +using Connected.Annotations; +using Connected.Entities.Annotations; +using System.Text.Json.Serialization; + +namespace Connected.Entities.Consistency; + +public abstract record ConsistentEntity : Entity + where TPrimaryKey : notnull +{ + [Ordinal(10000), ETag, JsonIgnore] + public string? ETag { get; init; } +} diff --git a/Connected.Entities/Consistency/IConsistentEntity.cs b/Connected.Entities/Consistency/IConsistentEntity.cs new file mode 100644 index 0000000..8097d98 --- /dev/null +++ b/Connected.Entities/Consistency/IConsistentEntity.cs @@ -0,0 +1,8 @@ +namespace Connected.Entities.Consistency +{ + public interface IConsistentEntity : IEntity + where T : notnull + { + string? ETag { get; } + } +} diff --git a/Connected.Entities/Containers/ContainerEntity.cs b/Connected.Entities/Containers/ContainerEntity.cs new file mode 100644 index 0000000..92fa472 --- /dev/null +++ b/Connected.Entities/Containers/ContainerEntity.cs @@ -0,0 +1,15 @@ +using Connected.Annotations; +using Connected.Entities.Annotations; +using Connected.Entities.Consistency; + +namespace Connected.Entities.Containers; +public abstract record ContainerEntity : ConsistentEntity + where TPrimaryKey : notnull +{ + + [Ordinal(-100), Length(128)] + public string Entity { get; init; } = default!; + + [Ordinal(-99), Length(128)] + public string EntityId { get; init; } = default!; +} diff --git a/Connected.Entities/EntitiesExtensions.cs b/Connected.Entities/EntitiesExtensions.cs new file mode 100644 index 0000000..fca3f4a --- /dev/null +++ b/Connected.Entities/EntitiesExtensions.cs @@ -0,0 +1,258 @@ +using System.Collections.Immutable; +using System.Linq.Expressions; +using System.Reflection; +using Connected.Data; +using Connected.Entities.Annotations; +using Connected.Entities.Query; +using Connected.Interop; +using Connected.Notifications; +using Connected.ServiceModel; + +namespace Connected.Entities; + +public static class EntitiesExtensions +{ + /// + /// Converts the object into entity and overwrites the provided properties. + /// + /// + /// All provided arguments are used when overwriting properties in the order they are specified. This means + /// the value from the last defined property is used when setting the entity's value. + /// + /// The type of the entity to create. + /// The arguments containing base property set. + /// The state modifier to which entity is set. + /// An array of additional modifier objects providing modified values. + /// A new instance of the entity with modified values. + public static TEntity AsEntity(this IDto args, State state, params object[] sources) + where TEntity : IEntity + { + if (typeof(TEntity).CreateInstance() is not TEntity instance) + throw new NullReferenceException(typeof(TEntity).Name); + + return Merge(instance, args, state, sources); + } + + public static TArgs AsArguments(this IEntity entity) where TArgs : IDto + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, entity); + } + + public static TArgs AsArguments(this IPrimaryKey entity) + where TArgs : IDto + where TPrimaryKey : notnull + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, entity); + } + + public static TArgs Patch(this IDto args, TEntity? entity, params object[] sources) + where TArgs : IDto + where TEntity : IEntity + { + if (entity is null) + { + var e = args.AsEntity(State.Default); + + return e.AsArguments(args); + } + + return Merge(entity, args, State.Default, sources).AsArguments(); + } + + public static TArgs AsArguments(this IEntity entity, params object[] sources) where TArgs : IDto + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, entity, sources); + } + + public static TArgs AsEventArguments(this IEntity entity) where TArgs : IEventArgs + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, entity); + } + + public static TEntity Merge(this TEntity existing, IDto modifier, State state, params object[] sources) + where TEntity : IEntity + { + var newEntity = Activator.CreateInstance(); + + return Serializer.Merge(newEntity, existing, modifier, new StateModifier { State = state }, sources); + } + + public static async Task> AsEntities(this IQueryable source, CancellationToken cancellationToken = default) + { + if (source is null) + return ImmutableList.Empty; + + var list = new List(); + + await foreach (var element in source.AsAsyncEnumerable().WithCancellation(cancellationToken)) + list.Add(element); + + return list.ToImmutableList(); + } + + public static async Task> AsEntities(this IEnumerable source) + { + if (source is null) + return ImmutableList.Empty; + + await Task.CompletedTask; + + return source.ToImmutableList(); + } + + public static async Task AsEntity(this IQueryable source, CancellationToken cancellationToken = default) + { + if (source is null) + return default; + + await Task.CompletedTask; + + return Execute(QueryableMethods.SingleOrDefaultWithoutPredicate, source, cancellationToken); + } + + public static async Task AsEntity(this IEnumerable source) + { + if (source is null) + return default; + + await Task.CompletedTask; + + return source.FirstOrDefault(); + } + + public static IEnumerable WithArguments(this IEnumerable source, QueryArgs args) + { + return source.WithOrderBy(args).WithPaging(args); + } + + private static IEnumerable WithPaging(this IEnumerable source, QueryArgs args) + { + if (args.Paging.Size < 1) + return source; + + return source.Skip((args.Paging.Index - 1) * args.Paging.Size) + .Take(args.Paging.Size); + } + + private static IEnumerable WithOrderBy(this IEnumerable entities, QueryArgs args) + { + if (entities.AsQueryable() as IOrderedQueryable is not IOrderedQueryable result) + return entities; + + var top = true; + + foreach (var criteria in args.OrderBy) + { + if (top) + { + if (criteria.Mode == OrderByMode.Ascending) + result = result.OrderBy(ResolvePropertyPredicate(criteria.Property)); + else + result = result.OrderByDescending(ResolvePropertyPredicate(criteria.Property)); + } + else + { + if (criteria.Mode == OrderByMode.Ascending) + result = result.ThenBy(ResolvePropertyPredicate(criteria.Property)); + else + result = result.ThenByDescending(ResolvePropertyPredicate(criteria.Property)); + } + + top = false; + } + + return result; + } + + private static Expression> ResolvePropertyPredicate(string propToOrder) + { + var param = Expression.Parameter(typeof(T)); + var memberAccess = Expression.Property(param, propToOrder); + var convertedMemberAccess = Expression.Convert(memberAccess, typeof(object)); + var orderPredicate = Expression.Lambda>(convertedMemberAccess, param); + + return orderPredicate; + } + + private static TResult? Execute(MethodInfo operatorMethodInfo, IQueryable source, Expression? expression, CancellationToken cancellationToken = default) + { + if (source.Provider is IAsyncQueryProvider provider) + { + if (operatorMethodInfo.IsGenericMethod) + { + operatorMethodInfo = operatorMethodInfo.GetGenericArguments().Length == 2 + ? operatorMethodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult).GetGenericArguments().Single()) + : operatorMethodInfo.MakeGenericMethod(typeof(TSource)); + } + + return (TResult)provider.Execute(Expression.Call(instance: null, method: operatorMethodInfo, arguments: expression == null + ? new[] { source.Expression } + : new[] { source.Expression, expression }), cancellationToken); + } + + throw new InvalidOperationException(); + } + + private static TResult? Execute(MethodInfo operatorMethodInfo, IQueryable source, CancellationToken cancellationToken = default) + { + return Execute(operatorMethodInfo, source, (Expression?)null, cancellationToken); + } + + public static PropertyInfo? PrimaryKeyProperty(this IEntity entity) + { + foreach (var property in entity.GetType().GetInheritedProperites()) + { + if (property.FindAttribute() is not null) + return property; + } + + return null; + } + + public static object? PrimaryKeyValue(this IEntity entity) + { + if (PrimaryKeyProperty(entity) is not PropertyInfo property) + return null; + + return property.GetValue(entity); + } + public static string EntityId(this IEntity entity) + { + var attribute = entity.GetSchemaAttribute(); + + return $"{attribute.Schema}.{attribute.Name}"; + } + public static SchemaAttribute GetSchemaAttribute(this IEntity entity) + { + var definedAttribute = entity.GetType().GetCustomAttribute(); + + if (definedAttribute is not null && definedAttribute.Schema is not null && definedAttribute.Name is not null) + return definedAttribute; + + var schema = string.IsNullOrEmpty(definedAttribute?.Schema) ? SchemaAttribute.DefaultSchema : definedAttribute.Schema; + var name = string.IsNullOrEmpty(definedAttribute?.Name) ? entity.GetType().Name.ToPascalCase() : definedAttribute.Name; + + return new TableAttribute + { + Id = definedAttribute?.Id, + Name = name, + Schema = schema + }; + } + + public static IAsyncEnumerable AsAsyncEnumerable(this IQueryable source) + { + if (source is IAsyncEnumerable asyncEnumerable) + return asyncEnumerable; + + throw new InvalidOperationException(); + } +} diff --git a/Connected.Entities/EntitiesStartup.cs b/Connected.Entities/EntitiesStartup.cs new file mode 100644 index 0000000..bbe5b05 --- /dev/null +++ b/Connected.Entities/EntitiesStartup.cs @@ -0,0 +1,15 @@ +using Connected.Annotations; +using Connected.Entities.Caching; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Entities; + +internal class EntitiesStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(IEntityCacheContext), typeof(EntityCacheContext)); + } +} diff --git a/Connected.Entities/Entity.cs b/Connected.Entities/Entity.cs new file mode 100644 index 0000000..2c21c4e --- /dev/null +++ b/Connected.Entities/Entity.cs @@ -0,0 +1,25 @@ +using Connected.Annotations; +using Connected.Caching.Annotations; +using Connected.Entities.Annotations; +using System.ComponentModel; +using System.Text.Json.Serialization; + +namespace Connected.Entities; + +public abstract record Entity : IEntity +{ + [DefaultValue(State.Default), JsonIgnore, Persistence(Persistence = ColumnPersistence.InMemory)] + public State State { get; init; } +} + +public abstract record Entity : Entity, IEntity + where T : notnull +{ + protected Entity() + { + Id = default!; + } + + [PrimaryKey, CacheKey, ReturnValue, Ordinal(-10000)] + public virtual T Id { get; init; } +} diff --git a/Connected.Entities/EntityContainer.cs b/Connected.Entities/EntityContainer.cs new file mode 100644 index 0000000..ef11044 --- /dev/null +++ b/Connected.Entities/EntityContainer.cs @@ -0,0 +1,17 @@ +using Connected.Annotations; +using Connected.Data; +using Connected.Entities.Annotations; +using Connected.Entities.Consistency; + +namespace Connected.Entities; +/// > +public abstract record EntityContainer : ConsistentEntity, IEntityContainer + where TPrimaryKey : notnull +{ + /// > + [Ordinal(-10), Length(128)] + public string Entity { get; init; } = default!; + /// > + [Ordinal(-9), Length(128)] + public string EntityId { get; init; } = default!; +} diff --git a/Connected.Entities/EntityExceptions.cs b/Connected.Entities/EntityExceptions.cs new file mode 100644 index 0000000..8b4073c --- /dev/null +++ b/Connected.Entities/EntityExceptions.cs @@ -0,0 +1,10 @@ +namespace Connected.Entities +{ + public static class EntityExceptions + { + public static InvalidCastException EntityCastException(Type componentType, Type entityType) + { + return new InvalidCastException($"{SR.ErrEntityCreate} ({componentType.Name}->{entityType.Name})"); + } + } +} diff --git a/Connected.Entities/IEntity.cs b/Connected.Entities/IEntity.cs new file mode 100644 index 0000000..62e84f1 --- /dev/null +++ b/Connected.Entities/IEntity.cs @@ -0,0 +1,20 @@ +using Connected.Data; + +namespace Connected.Entities; + +public enum State : byte +{ + Default = 0, + New = 1, + Deleted = 2 +} + +public interface IEntity +{ + State State { get; init; } +} + +public interface IEntity : IEntity, IPrimaryKey + where T : notnull +{ +} diff --git a/Connected.Entities/Query/IAsyncQueryProvider.cs b/Connected.Entities/Query/IAsyncQueryProvider.cs new file mode 100644 index 0000000..75a1b78 --- /dev/null +++ b/Connected.Entities/Query/IAsyncQueryProvider.cs @@ -0,0 +1,9 @@ +using System.Linq.Expressions; + +namespace Connected.Entities.Query +{ + public interface IAsyncQueryProvider : IQueryProvider + { + object Execute(Expression expression, CancellationToken cancellationToken = default); + } +} diff --git a/Connected.Entities/QueryableMethods.cs b/Connected.Entities/QueryableMethods.cs new file mode 100644 index 0000000..0888ef0 --- /dev/null +++ b/Connected.Entities/QueryableMethods.cs @@ -0,0 +1,29 @@ +using System.Reflection; + +namespace Connected.Entities +{ + internal static class QueryableMethods + { + static QueryableMethods() + { + var queryableMethodGroups = typeof(Queryable) + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) + .GroupBy(mi => mi.Name) + .ToDictionary(e => e.Key, l => l.ToList()); + + SingleWithoutPredicate = GetMethod(nameof(Queryable.Single), 1, types => new[] { typeof(IQueryable<>).MakeGenericType(types[0]) }); + SingleOrDefaultWithoutPredicate = GetMethod(nameof(Queryable.SingleOrDefault), 1, types => new[] { typeof(IQueryable<>).MakeGenericType(types[0]) }); + + MethodInfo GetMethod(string name, int genericParameterCount, Func parameterGenerator) + { + return queryableMethodGroups[name].Single(mi => ((genericParameterCount == 0 && !mi.IsGenericMethod) + || (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount)) + && mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : Array.Empty()))); + } + } + + public static MethodInfo SingleWithoutPredicate { get; } + public static MethodInfo SingleOrDefaultWithoutPredicate { get; } + + } +} diff --git a/Connected.Entities/SR.Designer.cs b/Connected.Entities/SR.Designer.cs new file mode 100644 index 0000000..5daecd2 --- /dev/null +++ b/Connected.Entities/SR.Designer.cs @@ -0,0 +1,81 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Entities { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Entities.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Concurrent access exception. + /// + internal static string ErrConcurrent { + get { + return ResourceManager.GetString("ErrConcurrent", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot create entity instance. + /// + internal static string ErrEntityCreate { + get { + return ResourceManager.GetString("ErrEntityCreate", resourceCulture); + } + } + } +} diff --git a/Connected.Entities/SR.resx b/Connected.Entities/SR.resx new file mode 100644 index 0000000..75af40e --- /dev/null +++ b/Connected.Entities/SR.resx @@ -0,0 +1,126 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Concurrent access exception + + + Cannot create entity instance + + \ No newline at end of file diff --git a/Connected.Entities/StateModifier.cs b/Connected.Entities/StateModifier.cs new file mode 100644 index 0000000..b420df1 --- /dev/null +++ b/Connected.Entities/StateModifier.cs @@ -0,0 +1,7 @@ +namespace Connected.Entities +{ + internal class StateModifier + { + public State State { get; init; } + } +} diff --git a/Connected.Entities/Storage/IStorage.cs b/Connected.Entities/Storage/IStorage.cs new file mode 100644 index 0000000..9fe9275 --- /dev/null +++ b/Connected.Entities/Storage/IStorage.cs @@ -0,0 +1,32 @@ +using System.Collections; +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Entities.Storage; +/// +/// Defines the read and write operations on the supported storage providers. +/// +/// >The type of the entitiy on which operations are performed. +public interface IStorage : IQueryable, IQueryable, IEnumerable, IEnumerable, IOrderedQueryable, IOrderedQueryable + where TEntity : IEntity +{ + /// + /// Updates the entity against the underlying storage. + /// + /// + /// + Task Update(TEntity? entity); + + Task Update(TEntity? entity, TArgs args, Func>? concurrencyRetrying) + where TArgs : IDto; + Task Update(TEntity? entity, TArgs args, Func>? concurrencyRetrying, Func>? merging) + where TArgs : IDto; + + Task> OpenReaders(StorageContextArgs args); + + Task Execute(StorageContextArgs args); + + Task> Query(StorageContextArgs args); + + Task Select(StorageContextArgs args); +} diff --git a/Connected.Entities/Storage/IStorageOperation.cs b/Connected.Entities/Storage/IStorageOperation.cs new file mode 100644 index 0000000..d5f566c --- /dev/null +++ b/Connected.Entities/Storage/IStorageOperation.cs @@ -0,0 +1,21 @@ +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Entities.Storage +{ + public enum DataConcurrencyMode + { + Enabled = 1, + Disabled = 2, + } + + public interface IStorageOperation + { + string CommandText { get; } + CommandType CommandType { get; } + + ImmutableList? Parameters { get; } + int CommandTimeout { get; } + DataConcurrencyMode Concurrency { get; } + } +} diff --git a/Connected.Entities/Storage/IStorageParameter.cs b/Connected.Entities/Storage/IStorageParameter.cs new file mode 100644 index 0000000..f056584 --- /dev/null +++ b/Connected.Entities/Storage/IStorageParameter.cs @@ -0,0 +1,12 @@ +using System.Data; + +namespace Connected.Entities.Storage +{ + public interface IStorageParameter + { + string? Name { get; init; } + object? Value { get; set; } + ParameterDirection Direction { get; init; } + DbType Type { get; init; } + } +} diff --git a/Connected.Entities/Storage/IStorageProvider.cs b/Connected.Entities/Storage/IStorageProvider.cs new file mode 100644 index 0000000..a7b6747 --- /dev/null +++ b/Connected.Entities/Storage/IStorageProvider.cs @@ -0,0 +1,13 @@ +namespace Connected.Entities.Storage; + +public enum StorageConnectionMode +{ + Shared = 1, + Isolated = 2 +} + +public interface IStorageProvider +{ + IStorage Open() + where TEntity : IEntity; +} diff --git a/Connected.Entities/Storage/StorageContextArgs.cs b/Connected.Entities/Storage/StorageContextArgs.cs new file mode 100644 index 0000000..1d29150 --- /dev/null +++ b/Connected.Entities/Storage/StorageContextArgs.cs @@ -0,0 +1,14 @@ +namespace Connected.Entities.Storage; + +public class StorageContextArgs : EventArgs +{ + public StorageContextArgs(IStorageOperation operation) + { + if (operation is null) + throw new ArgumentException(null, nameof(operation)); + + Operation = operation; + } + + public IStorageOperation Operation { get; } +} diff --git a/Connected.Entities/Storage/StorageOperation.cs b/Connected.Entities/Storage/StorageOperation.cs new file mode 100644 index 0000000..972eec6 --- /dev/null +++ b/Connected.Entities/Storage/StorageOperation.cs @@ -0,0 +1,29 @@ +using System.Collections.Immutable; +using System.Data; + +namespace Connected.Entities.Storage; + +public class StorageOperation : IStorageOperation +{ + private List _parameters; + + public StorageOperation() + { + _parameters = new(); + } + + public string CommandText { get; init; } + + public CommandType CommandType { get; init; } = CommandType.Text; + + public ImmutableList? Parameters => _parameters.ToImmutableList(); + + public int CommandTimeout { get; init; } = 30; + + public DataConcurrencyMode Concurrency { get; init; } = DataConcurrencyMode.Enabled; + + public void AddParameter(IStorageParameter parameter) + { + _parameters.Add(parameter); + } +} diff --git a/Connected.Entities/Storage/StorageParameter.cs b/Connected.Entities/Storage/StorageParameter.cs new file mode 100644 index 0000000..d13d088 --- /dev/null +++ b/Connected.Entities/Storage/StorageParameter.cs @@ -0,0 +1,11 @@ +using System.Data; + +namespace Connected.Entities.Storage; + +public class StorageParameter : IStorageParameter +{ + public string? Name { get; init; } + public object? Value { get; set; } + public ParameterDirection Direction { get; init; } = ParameterDirection.Input; + public DbType Type { get; init; } = DbType.String; +} diff --git a/Connected.Expressions/Collections/CollectionExtensions.cs b/Connected.Expressions/Collections/CollectionExtensions.cs new file mode 100644 index 0000000..047916d --- /dev/null +++ b/Connected.Expressions/Collections/CollectionExtensions.cs @@ -0,0 +1,24 @@ +using System.Collections.ObjectModel; + +namespace Connected.Expressions.Collections; + +internal static class Extensions +{ + public static ReadOnlyCollection ToReadOnly(this IEnumerable sequence) + { + if (sequence is not ReadOnlyCollection collection) + { + if (sequence is null) + collection = EmptyReadOnlyCollection.Empty; + else + collection = new List(sequence).AsReadOnly(); + } + + return collection; + } + + private class EmptyReadOnlyCollection + { + public static readonly ReadOnlyCollection Empty = new List().AsReadOnly(); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Collections/DeferredList.cs b/Connected.Expressions/Collections/DeferredList.cs new file mode 100644 index 0000000..0366014 --- /dev/null +++ b/Connected.Expressions/Collections/DeferredList.cs @@ -0,0 +1,206 @@ +using System.Collections; + +namespace Connected.Expressions.Collections; + +internal sealed class DeferredList : IDeferredList, ICollection, IEnumerable, IList, ICollection, IEnumerable, IDeferLoadable +{ + private readonly IEnumerable _source; + + public DeferredList(IEnumerable source) + { + _source = source; + } + private IEnumerable Source => _source; + private List Values { get; set; } + + public void Load() + { + if (!IsLoaded) + Values = new List(Source); + } + + public bool IsLoaded => Values is not null; + + #region IList Members + + public int IndexOf(T item) + { + Load(); + + return Values.IndexOf(item); + } + + public void Insert(int index, T item) + { + Load(); + + Values.Insert(index, item); + } + + public void RemoveAt(int index) + { + Load(); + + Values.RemoveAt(index); + } + + public T this[int index] + { + get + { + Load(); + + return Values[index]; + } + set + { + Load(); + + Values[index] = value; + } + } + + #endregion + + #region ICollection Members + + public void Add(T item) + { + Load(); + + Values.Add(item); + } + + public void Clear() + { + Load(); + + Values.Clear(); + } + + public bool Contains(T item) + { + Load(); + + return Values.Contains(item); + } + + public void CopyTo(T[] array, int arrayIndex) + { + Load(); + + Values.CopyTo(array, arrayIndex); + } + + public int Count + { + get + { + Load(); + + return Values.Count; + } + } + + public bool IsReadOnly => false; + + public bool Remove(T item) + { + Load(); + + return Values.Remove(item); + } + + #endregion + + #region IEnumerable Members + + public IEnumerator GetEnumerator() + { + Load(); + + return Values.GetEnumerator(); + } + + #endregion + + #region IEnumerable Members + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + #endregion + + #region IList Members + + public int Add(object? value) + { + Load(); + + return ((IList)Values).Add(value); + } + + public bool Contains(object? value) + { + Load(); + + return ((IList)Values).Contains(value); + } + + public int IndexOf(object? value) + { + Load(); + + return ((IList)Values).IndexOf(value); + } + + public void Insert(int index, object? value) + { + Load(); + + ((IList)Values).Insert(index, value); + } + + public bool IsFixedSize => false; + + public void Remove(object value) + { + Load(); + + ((IList)Values).Remove(value); + } + + object IList.this[int index] + { + get + { + Load(); + + return ((IList)Values)[index]; + } + set + { + Load(); + + ((IList)Values)[index] = value; + } + } + + #endregion + + #region ICollection Members + + public void CopyTo(Array array, int index) + { + Load(); + + ((IList)Values).CopyTo(array, index); + } + + public bool IsSynchronized => false; + public object? SyncRoot => default; + + #endregion +} \ No newline at end of file diff --git a/Connected.Expressions/Collections/EnumerateOnce.cs b/Connected.Expressions/Collections/EnumerateOnce.cs new file mode 100644 index 0000000..513ddbc --- /dev/null +++ b/Connected.Expressions/Collections/EnumerateOnce.cs @@ -0,0 +1,28 @@ +using System.Collections; + +namespace Connected.Expressions.Collections; + +internal class EnumerateOnce : IEnumerable, IEnumerable +{ + private IEnumerable? _enumerable; + + public EnumerateOnce(IEnumerable enumerable) + { + _enumerable = enumerable; + } + + public IEnumerator GetEnumerator() + { + var en = Interlocked.Exchange(ref _enumerable, null); + + if (en is not null) + return en.GetEnumerator(); + + throw new Exception("Enumerated more than once."); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } +} diff --git a/Connected.Expressions/Collections/IDeferLoadable.cs b/Connected.Expressions/Collections/IDeferLoadable.cs new file mode 100644 index 0000000..93bf7c7 --- /dev/null +++ b/Connected.Expressions/Collections/IDeferLoadable.cs @@ -0,0 +1,7 @@ +namespace Connected.Expressions.Collections; + +internal interface IDeferLoadable +{ + bool IsLoaded { get; } + void Load(); +} diff --git a/Connected.Expressions/Collections/IDeferredList.cs b/Connected.Expressions/Collections/IDeferredList.cs new file mode 100644 index 0000000..36874f1 --- /dev/null +++ b/Connected.Expressions/Collections/IDeferredList.cs @@ -0,0 +1,11 @@ +using System.Collections; + +namespace Connected.Expressions.Collections; + +internal interface IDeferredList : IList, IDeferLoadable +{ +} + +internal interface IDeferredList : IList, IDeferredList +{ +} diff --git a/Connected.Expressions/Collections/ScopedDictionary.cs b/Connected.Expressions/Collections/ScopedDictionary.cs new file mode 100644 index 0000000..d71be68 --- /dev/null +++ b/Connected.Expressions/Collections/ScopedDictionary.cs @@ -0,0 +1,50 @@ +namespace Connected.Expressions.Collections; + +internal sealed class ScopedDictionary + where TKey : notnull +{ + public ScopedDictionary(ScopedDictionary? previous) + { + Previous = previous; + Map = new(); + } + + public ScopedDictionary(ScopedDictionary? previous, IEnumerable> pairs) + : this(previous) + { + foreach (var p in pairs) + Map.Add(p.Key, p.Value); + } + + private ScopedDictionary? Previous { get; } + private Dictionary Map { get; } + + public void Add(TKey key, TValue value) + { + Map.Add(key, value); + } + + public bool TryGetValue(TKey key, out TValue? value) + { + for (var scope = this; scope is not null; scope = scope.Previous) + { + if (scope.Map.TryGetValue(key, out value)) + return true; + } + + value = default; + + return false; + } + + public bool ContainsKey(TKey key) + { + for (var scope = this; scope is not null; scope = scope.Previous) + { + if (scope.Map.ContainsKey(key)) + return true; + } + + return false; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Comparers/DatabaseComparer.cs b/Connected.Expressions/Comparers/DatabaseComparer.cs new file mode 100644 index 0000000..26dd983 --- /dev/null +++ b/Connected.Expressions/Comparers/DatabaseComparer.cs @@ -0,0 +1,323 @@ +using Connected.Expressions.Collections; +using Connected.Expressions.Translation; +using Connected.Expressions.Translation.Resolvers; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions.Comparers; + +internal sealed class DatabaseComparer : ExpressionComparer +{ + protected DatabaseComparer(ScopedDictionary? parameterScope, Func? comparer, + ScopedDictionary? aliasScope) + : base(parameterScope, comparer) + { + AliasScope = aliasScope; + } + + private ScopedDictionary? AliasScope { get; set; } + public new static bool AreEqual(Expression? a, Expression? b) + { + return AreEqual(null, null, a, b, null); + } + + public new static bool AreEqual(Expression? a, Expression? b, Func? fnCompare) + { + return AreEqual(null, null, a, b, fnCompare); + } + + public static bool AreEqual(ScopedDictionary? parameterScope, ScopedDictionary? aliasScope, Expression? a, Expression? b) + { + return new DatabaseComparer(parameterScope, null, aliasScope).Compare(a, b); + } + + public static bool AreEqual(ScopedDictionary? parameterScope, ScopedDictionary? aliasScope, Expression? a, Expression? b, Func? fnCompare) + { + return new DatabaseComparer(parameterScope, fnCompare, aliasScope).Compare(a, b); + } + + protected override bool Compare(Expression? a, Expression? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.NodeType != b.NodeType) + return false; + + if (a.Type != b.Type) + return false; + + return (DatabaseExpressionType)a.NodeType switch + { + DatabaseExpressionType.Table => CompareTable((TableExpression)a, (TableExpression)b), + DatabaseExpressionType.Column => CompareColumn((ColumnExpression)a, (ColumnExpression)b), + DatabaseExpressionType.Select => CompareSelect((SelectExpression)a, (SelectExpression)b), + DatabaseExpressionType.Join => CompareJoin((JoinExpression)a, (JoinExpression)b), + DatabaseExpressionType.Aggregate => CompareAggregate((AggregateExpression)a, (AggregateExpression)b), + DatabaseExpressionType.Scalar or DatabaseExpressionType.Exists or DatabaseExpressionType.In => CompareSubquery((SubqueryExpression)a, (SubqueryExpression)b), + DatabaseExpressionType.AggregateSubquery => CompareAggregateSubquery((AggregateSubqueryExpression)a, (AggregateSubqueryExpression)b), + DatabaseExpressionType.IsNull => CompareIsNull((IsNullExpression)a, (IsNullExpression)b), + DatabaseExpressionType.Between => CompareBetween((BetweenExpression)a, (BetweenExpression)b), + DatabaseExpressionType.RowCount => CompareRowNumber((RowNumberExpression)a, (RowNumberExpression)b), + DatabaseExpressionType.Projection => CompareProjection((ProjectionExpression)a, (ProjectionExpression)b), + DatabaseExpressionType.NamedValue => CompareNamedValue((NamedValueExpression)a, (NamedValueExpression)b), + DatabaseExpressionType.Batch => CompareBatch((BatchExpression)a, (BatchExpression)b), + DatabaseExpressionType.Function => CompareFunction((FunctionExpression)a, (FunctionExpression)b), + DatabaseExpressionType.Entity => CompareEntity((EntityExpression)a, (EntityExpression)b), + DatabaseExpressionType.If => CompareIf((IfCommandExpression)a, (IfCommandExpression)b), + DatabaseExpressionType.Block => CompareBlock((BlockExpression)a, (BlockExpression)b), + _ => base.Compare(a, b), + }; + } + + private static bool CompareTable(TableExpression a, TableExpression b) + { + return a.Name == b.Name; + } + + private bool CompareColumn(ColumnExpression a, ColumnExpression b) + { + return CompareAlias(a.Alias, b.Alias) && a.Name == b.Name; + } + + private bool CompareAlias(Alias a, Alias b) + { + if (AliasScope is not null) + { + if (AliasScope.TryGetValue(a, out Alias? mapped)) + return mapped == b; + } + + return a == b; + } + + private bool CompareSelect(SelectExpression a, SelectExpression b) + { + var save = AliasScope; + + try + { + if (!Compare(a.From, b.From)) + return false; + + AliasScope = new ScopedDictionary(save); + MapAliases(a.From, b.From); + + return Compare(a.Where, b.Where) + && CompareOrderList(a.OrderBy, b.OrderBy) + && CompareExpressionList(a.GroupBy, b.GroupBy) + && Compare(a.Skip, b.Skip) + && Compare(a.Take, b.Take) + && a.IsDistinct == b.IsDistinct + && a.IsReverse == b.IsReverse + && CompareColumnDeclarations(a.Columns, b.Columns); + } + finally + { + AliasScope = save; + } + } + + private void MapAliases(Expression a, Expression b) + { + if (AliasScope is null) + throw new NullReferenceException(nameof(AliasScope)); + + var prodA = DeclaredAliasesResolver.Resolve(a).ToArray(); + var prodB = DeclaredAliasesResolver.Resolve(b).ToArray(); + + for (int i = 0, n = prodA.Length; i < n; i++) + AliasScope.Add(prodA[i], prodB[i]); + } + + private bool CompareOrderList(ReadOnlyCollection? a, ReadOnlyCollection? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + var left = a[i]; + var right = b[i]; + + if (left.OrderType != right.OrderType || !Compare(left.Expression, right.Expression)) + return false; + } + + return true; + } + + private bool CompareColumnDeclarations(ReadOnlyCollection? a, ReadOnlyCollection? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + if (!CompareColumnDeclaration(a[i], b[i])) + return false; + } + + return true; + } + + private bool CompareColumnDeclaration(ColumnDeclaration a, ColumnDeclaration b) + { + return string.Equals(a.Name, b.Name, StringComparison.OrdinalIgnoreCase) && Compare(a.Expression, b.Expression); + } + + private bool CompareJoin(JoinExpression a, JoinExpression b) + { + if (a.Join != b.Join || !Compare(a.Left, b.Left)) + return false; + + if (a.Join == JoinType.CrossApply || a.Join == JoinType.OuterApply) + { + var save = AliasScope; + + try + { + AliasScope = new ScopedDictionary(AliasScope); + MapAliases(a.Left, b.Left); + + return Compare(a.Right, b.Right) && Compare(a.Condition, b.Condition); + } + finally + { + AliasScope = save; + } + } + else + return Compare(a.Right, b.Right) && Compare(a.Condition, b.Condition); + } + + private bool CompareAggregate(AggregateExpression a, AggregateExpression b) + { + return string.Equals(a.AggregateName, b.AggregateName, StringComparison.OrdinalIgnoreCase) && Compare(a.Argument, b.Argument); + } + + private bool CompareIsNull(IsNullExpression a, IsNullExpression b) + { + return Compare(a.Expression, b.Expression); + } + + private bool CompareBetween(BetweenExpression a, BetweenExpression b) + { + return Compare(a.Expression, b.Expression) && Compare(a.Lower, b.Lower) && Compare(a.Upper, b.Upper); + } + + private bool CompareRowNumber(RowNumberExpression a, RowNumberExpression b) + { + return CompareOrderList(a.OrderBy, b.OrderBy); + } + + private bool CompareNamedValue(NamedValueExpression a, NamedValueExpression b) + { + return string.Equals(a.Name, b.Name, StringComparison.OrdinalIgnoreCase) && Compare(a.Value, b.Value); + } + + private bool CompareSubquery(SubqueryExpression a, SubqueryExpression b) + { + if (a.NodeType != b.NodeType) + return false; + + return (DatabaseExpressionType)a.NodeType switch + { + DatabaseExpressionType.Scalar => CompareScalar((ScalarExpression)a, (ScalarExpression)b), + DatabaseExpressionType.Exists => CompareExists((ExistsExpression)a, (ExistsExpression)b), + DatabaseExpressionType.In => CompareIn((InExpression)a, (InExpression)b), + _ => false, + }; + } + + private bool CompareScalar(ScalarExpression a, ScalarExpression b) + { + return Compare(a.Select, b.Select); + } + + private bool CompareExists(ExistsExpression a, ExistsExpression b) + { + return Compare(a.Select, b.Select); + } + + private bool CompareIn(InExpression a, InExpression b) + { + return Compare(a.Expression, b.Expression) && Compare(a.Select, b.Select) && CompareExpressionList(a.Values, b.Values); + } + + private bool CompareAggregateSubquery(AggregateSubqueryExpression a, AggregateSubqueryExpression b) + { + return Compare(a.AggregateAsSubquery, b.AggregateAsSubquery) && Compare(a.AggregateInGroupSelect, b.AggregateInGroupSelect) && a.GroupByAlias == b.GroupByAlias; + } + + private bool CompareProjection(ProjectionExpression a, ProjectionExpression b) + { + if (!Compare(a.Select, b.Select)) + return false; + + var save = AliasScope; + + try + { + AliasScope = new ScopedDictionary(AliasScope); + AliasScope.Add(a.Select.Alias, b.Select.Alias); + + return Compare(a.Projector, b.Projector) + && Compare(a.Aggregator, b.Aggregator) + && a.IsSingleton == b.IsSingleton; + } + finally + { + AliasScope = save; + } + } + + private bool CompareBatch(BatchExpression x, BatchExpression y) + { + return Compare(x.Input, y.Input) && Compare(x.Operation, y.Operation) && Compare(x.BatchSize, y.BatchSize) && Compare(x.Stream, y.Stream); + } + + private bool CompareIf(IfCommandExpression x, IfCommandExpression y) + { + return Compare(x.Check, y.Check) && Compare(x.IfTrue, y.IfTrue) && Compare(x.IfFalse, y.IfFalse); + } + + private bool CompareBlock(BlockExpression x, BlockExpression y) + { + if (x.Commands.Count != y.Commands.Count) + return false; + + for (var i = 0; i < x.Commands.Count; i++) + { + if (!Compare(x.Commands[i], y.Commands[i])) + return false; + } + + return true; + } + + private bool CompareFunction(FunctionExpression x, FunctionExpression y) + { + return string.Equals(x.Name, y.Name, StringComparison.OrdinalIgnoreCase) && CompareExpressionList(x.Arguments, y.Arguments); + } + + private bool CompareEntity(EntityExpression x, EntityExpression y) + { + return x.EntityType == y.EntityType && Compare(x.Expression, y.Expression); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Comparers/ExpressionComparer.cs b/Connected.Expressions/Comparers/ExpressionComparer.cs new file mode 100644 index 0000000..3848b5d --- /dev/null +++ b/Connected.Expressions/Comparers/ExpressionComparer.cs @@ -0,0 +1,317 @@ +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using Connected.Expressions.Collections; + +namespace Connected.Expressions.Comparers; + +internal class ExpressionComparer +{ + protected ExpressionComparer(ScopedDictionary? parameterScope, Func? comparer) + { + ParameterScope = parameterScope; + Comparer = comparer; + } + + protected Func? Comparer { get; } + private ScopedDictionary? ParameterScope { get; set; } + + public static bool AreEqual(Expression a, Expression b) + { + return AreEqual(null, a, b); + } + + public static bool AreEqual(Expression a, Expression b, Func? fnCompare) + { + return AreEqual(null, a, b, fnCompare); + } + + public static bool AreEqual(ScopedDictionary? parameterScope, Expression a, Expression b) + { + return new ExpressionComparer(parameterScope, null).Compare(a, b); + } + + public static bool AreEqual(ScopedDictionary? parameterScope, Expression a, Expression b, Func? fnCompare) + { + return new ExpressionComparer(parameterScope, fnCompare).Compare(a, b); + } + + protected virtual bool Compare(Expression? a, Expression? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.NodeType != b.NodeType) + return false; + + if (a.Type != b.Type) + return false; + + return a.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Not or ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.ArrayLength or ExpressionType.Quote or ExpressionType.TypeAs or ExpressionType.UnaryPlus => CompareUnary((UnaryExpression)a, (UnaryExpression)b), + ExpressionType.Add or ExpressionType.AddChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked or ExpressionType.Multiply or ExpressionType.MultiplyChecked or ExpressionType.Divide or ExpressionType.Modulo or ExpressionType.And or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse or ExpressionType.LessThan or ExpressionType.LessThanOrEqual or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.Equal or ExpressionType.NotEqual or ExpressionType.Coalesce or ExpressionType.ArrayIndex or ExpressionType.RightShift or ExpressionType.LeftShift or ExpressionType.ExclusiveOr or ExpressionType.Power => CompareBinary((BinaryExpression)a, (BinaryExpression)b), + ExpressionType.TypeIs => CompareTypeIs((TypeBinaryExpression)a, (TypeBinaryExpression)b), + ExpressionType.Conditional => CompareConditional((ConditionalExpression)a, (ConditionalExpression)b), + ExpressionType.Constant => CompareConstant((ConstantExpression)a, (ConstantExpression)b), + ExpressionType.Parameter => CompareParameter((ParameterExpression)a, (ParameterExpression)b), + ExpressionType.MemberAccess => CompareMemberAccess((MemberExpression)a, (MemberExpression)b), + ExpressionType.Call => CompareMethodCall((MethodCallExpression)a, (MethodCallExpression)b), + ExpressionType.Lambda => CompareLambda((LambdaExpression)a, (LambdaExpression)b), + ExpressionType.New => CompareNew((NewExpression)a, (NewExpression)b), + ExpressionType.NewArrayInit or ExpressionType.NewArrayBounds => CompareNewArray((NewArrayExpression)a, (NewArrayExpression)b), + ExpressionType.Invoke => CompareInvocation((InvocationExpression)a, (InvocationExpression)b), + ExpressionType.MemberInit => CompareMemberInit((MemberInitExpression)a, (MemberInitExpression)b), + ExpressionType.ListInit => CompareListInit((ListInitExpression)a, (ListInitExpression)b), + _ => throw new NotSupportedException($"Unhandled expression type: '{a.NodeType}'") + }; + } + + protected virtual bool CompareUnary(UnaryExpression a, UnaryExpression b) + { + return a.NodeType == b.NodeType + && a.Method == b.Method + && a.IsLifted == b.IsLifted + && a.IsLiftedToNull == b.IsLiftedToNull + && Compare(a.Operand, b.Operand); + } + + protected virtual bool CompareBinary(BinaryExpression a, BinaryExpression b) + { + return a.NodeType == b.NodeType + && a.Method == b.Method + && a.IsLifted == b.IsLifted + && a.IsLiftedToNull == b.IsLiftedToNull + && Compare(a.Left, b.Left) + && Compare(a.Right, b.Right); + } + + protected virtual bool CompareTypeIs(TypeBinaryExpression a, TypeBinaryExpression b) + { + return a.TypeOperand == b.TypeOperand + && Compare(a.Expression, b.Expression); + } + + protected virtual bool CompareConditional(ConditionalExpression a, ConditionalExpression b) + { + return Compare(a.Test, b.Test) + && Compare(a.IfTrue, b.IfTrue) + && Compare(a.IfFalse, b.IfFalse); + } + + protected virtual bool CompareConstant(ConstantExpression a, ConstantExpression b) + { + if (Comparer is not null) + return Comparer(a.Value, b.Value); + else + return Equals(a.Value, b.Value); + } + + protected virtual bool CompareParameter(ParameterExpression a, ParameterExpression b) + { + if (ParameterScope is not null) + { + if (ParameterScope.TryGetValue(a, out ParameterExpression? mapped)) + return mapped == b; + } + + return a == b; + } + + protected virtual bool CompareMemberAccess(MemberExpression a, MemberExpression b) + { + return a.Member == b.Member + && Compare(a.Expression, b.Expression); + } + + protected virtual bool CompareMethodCall(MethodCallExpression a, MethodCallExpression b) + { + return a.Method == b.Method + && Compare(a.Object, b.Object) + && CompareExpressionList(a.Arguments, b.Arguments); + } + + protected virtual bool CompareLambda(LambdaExpression a, LambdaExpression b) + { + var n = a.Parameters.Count; + + if (b.Parameters.Count != n) + return false; + + for (var i = 0; i < n; i++) + { + if (a.Parameters[i].Type != b.Parameters[i].Type) + return false; + } + + var save = ParameterScope; + + ParameterScope = new ScopedDictionary(null); + + try + { + for (var i = 0; i < n; i++) + ParameterScope.Add(a.Parameters[i], b.Parameters[i]); + + return Compare(a.Body, b.Body); + } + finally + { + ParameterScope = save; + } + } + + protected virtual bool CompareNew(NewExpression a, NewExpression b) + { + return a.Constructor == b.Constructor + && CompareExpressionList(a.Arguments, b.Arguments) + && CompareMemberList(a.Members, b.Members); + } + + protected virtual bool CompareExpressionList(ReadOnlyCollection? a, ReadOnlyCollection? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + if (!Compare(a[i], b[i])) + return false; + } + + return true; + } + + protected virtual bool CompareMemberList(ReadOnlyCollection? a, ReadOnlyCollection? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + if (a[i] != b[i]) + return false; + } + + return true; + } + + protected virtual bool CompareNewArray(NewArrayExpression a, NewArrayExpression b) + { + return CompareExpressionList(a.Expressions, b.Expressions); + } + + protected virtual bool CompareInvocation(InvocationExpression a, InvocationExpression b) + { + return Compare(a.Expression, b.Expression) && CompareExpressionList(a.Arguments, b.Arguments); + } + + protected virtual bool CompareMemberInit(MemberInitExpression a, MemberInitExpression b) + { + return Compare(a.NewExpression, b.NewExpression) && CompareBindingList(a.Bindings, b.Bindings); + } + + protected virtual bool CompareBindingList(ReadOnlyCollection a, ReadOnlyCollection b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + if (!CompareBinding(a[i], b[i])) + return false; + } + + return true; + } + + protected virtual bool CompareBinding(MemberBinding a, MemberBinding b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.BindingType != b.BindingType) + return false; + + if (a.Member != b.Member) + return false; + + return a.BindingType switch + { + MemberBindingType.Assignment => CompareMemberAssignment((MemberAssignment)a, (MemberAssignment)b), + MemberBindingType.ListBinding => CompareMemberListBinding((MemberListBinding)a, (MemberListBinding)b), + MemberBindingType.MemberBinding => CompareMemberMemberBinding((MemberMemberBinding)a, (MemberMemberBinding)b), + _ => throw new NotSupportedException($"Unhandled binding type: '{a.BindingType}'") + }; + } + + protected virtual bool CompareMemberAssignment(MemberAssignment a, MemberAssignment b) + { + return a.Member == b.Member && Compare(a.Expression, b.Expression); + } + + protected virtual bool CompareMemberListBinding(MemberListBinding a, MemberListBinding b) + { + return a.Member == b.Member && CompareElementInitList(a.Initializers, b.Initializers); + } + + protected virtual bool CompareMemberMemberBinding(MemberMemberBinding a, MemberMemberBinding b) + { + return a.Member == b.Member && CompareBindingList(a.Bindings, b.Bindings); + } + + protected virtual bool CompareListInit(ListInitExpression a, ListInitExpression b) + { + return Compare(a.NewExpression, b.NewExpression) && CompareElementInitList(a.Initializers, b.Initializers); + } + + protected virtual bool CompareElementInitList(ReadOnlyCollection? a, ReadOnlyCollection? b) + { + if (a == b) + return true; + + if (a is null || b is null) + return false; + + if (a.Count != b.Count) + return false; + + for (var i = 0; i < a.Count; i++) + { + if (!CompareElementInit(a[i], b[i])) + return false; + } + + return true; + } + + protected virtual bool CompareElementInit(ElementInit a, ElementInit b) + { + return a.AddMethod == b.AddMethod + && CompareExpressionList(a.Arguments, b.Arguments); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Connected.Expressions.csproj b/Connected.Expressions/Connected.Expressions.csproj new file mode 100644 index 0000000..76502af --- /dev/null +++ b/Connected.Expressions/Connected.Expressions.csproj @@ -0,0 +1,15 @@ + + + + net7.0 + enable + enable + + + + + + + + + diff --git a/Connected.Expressions/Evaluation/ColumnNominator.cs b/Connected.Expressions/Evaluation/ColumnNominator.cs new file mode 100644 index 0000000..e52b6d4 --- /dev/null +++ b/Connected.Expressions/Evaluation/ColumnNominator.cs @@ -0,0 +1,49 @@ +using System.Linq.Expressions; +using Connected.Expressions.Mappings; +using ExpressionVisitor = Connected.Expressions.Visitors.ExpressionVisitor; + +namespace Connected.Expressions.Evaluation; + +internal sealed class ColumnNominator : ExpressionVisitor +{ + private ColumnNominator() + { + Candidates = new HashSet(); + } + + private HashSet Candidates { get; set; } + private bool CannotBeEvaluated { get; set; } + + public static HashSet Nominate(Expression expression) + { + var nominator = new ColumnNominator(); + + nominator.Visit(expression); + + return nominator.Candidates; + } + + protected override Expression? Visit(Expression? expression) + { + if (expression is not null) + { + var saveCannotBeEvaluated = CannotBeEvaluated; + + CannotBeEvaluated = false; + + base.Visit(expression); + + if (!CannotBeEvaluated) + { + if (MappingsCache.CanEvaluateLocally(expression)) + Candidates.Add(expression); + else + CannotBeEvaluated = true; + } + + CannotBeEvaluated |= saveCannotBeEvaluated; + } + + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Evaluation/Command.cs b/Connected.Expressions/Evaluation/Command.cs new file mode 100644 index 0000000..cdc259c --- /dev/null +++ b/Connected.Expressions/Evaluation/Command.cs @@ -0,0 +1,16 @@ +using System.Collections.ObjectModel; +using Connected.Expressions.Collections; + +namespace Connected.Expressions.Evaluation; + +internal sealed class Command +{ + public Command(string commandText, IEnumerable parameters) + { + CommandText = commandText; + Parameters = parameters.ToReadOnly(); + } + + public string CommandText { get; } + public ReadOnlyCollection Parameters { get; } +} diff --git a/Connected.Expressions/Evaluation/CommandParameter.cs b/Connected.Expressions/Evaluation/CommandParameter.cs new file mode 100644 index 0000000..81d7f44 --- /dev/null +++ b/Connected.Expressions/Evaluation/CommandParameter.cs @@ -0,0 +1,16 @@ +using Connected.Expressions.Languages; + +namespace Connected.Expressions.Evaluation; +internal sealed class CommandParameter +{ + public CommandParameter(string name, Type type, DataType dataType) + { + Name = name; + Type = type; + DataType = dataType; + } + + public string Name { get; } + public Type Type { get; } + public DataType DataType { get; } +} diff --git a/Connected.Expressions/Evaluation/ExecutionBuilder.cs b/Connected.Expressions/Evaluation/ExecutionBuilder.cs new file mode 100644 index 0000000..66e268c --- /dev/null +++ b/Connected.Expressions/Evaluation/ExecutionBuilder.cs @@ -0,0 +1,86 @@ +using Connected.Entities.Storage; +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Visitors; +using Connected.Interop; +using System.Data; +using System.Linq.Expressions; + +namespace Connected.Expressions.Evaluation; + +public sealed class ExecutionBuilder : DatabaseVisitor +{ + private ExecutionBuilder(ExpressionCompilationContext context, Linguist linguist, Expression executor) + { + Context = context; + Linguist = linguist; + Executor = executor; + } + + private ExpressionCompilationContext Context { get; } + private Linguist Linguist { get; } + private Expression Executor { get; } + + public static Expression> Build(ExpressionCompilationContext context, Linguist linguist, Expression expression, Expression provider) + { + var executor = Expression.Parameter(typeof(IStorageExecutor), "executor"); + var builder = new ExecutionBuilder(context, linguist, executor); + + return builder.Build(expression); + } + + private Expression> Build(Expression expression) + { + expression = Visit(expression); + expression = Expression.Lambda>(expression, (ParameterExpression)Executor); + + return (Expression>)expression; + } + + protected override Expression VisitProjection(ProjectionExpression projection) + { + /* + * parameterize query + */ + var commandText = Linguist.Format(projection.Select); + var command = new StorageOperation + { + CommandText = commandText + }; + + foreach (var parameter in Context.Parameters) + { + command.AddParameter(new StorageParameter + { + Direction = ParameterDirection.Input, + Name = $"@{parameter.Key}", + Type = parameter.Value.Type.ToDbType(), + Value = parameter.Value.Value + }); + }; + + return ExecuteProjection(projection, command); + } + + private Expression ExecuteProjection(ProjectionExpression projection, IStorageOperation operation) + { + var method = nameof(IStorageExecutor.Execute); + /* + * call low-level execute directly on supplied DbQueryProvider + */ + Type typeArgument; + + if (projection.Type.IsEnumerable()) + typeArgument = projection.Type.GetGenericArguments()[0]; + else + typeArgument = projection.Type; + + var constant = Expression.Constant(operation); + Expression body = Expression.Call(Executor, nameof(IStorageExecutor.Execute), new Type[] { typeArgument }, constant); + + if (projection.Aggregator is not null)// apply aggregator + body = ExpressionReplacer.Replace(projection.Aggregator.Body, projection.Aggregator.Parameters[0], body); + + return body; + } +} diff --git a/Connected.Expressions/Evaluation/ExpressionNominator.cs b/Connected.Expressions/Evaluation/ExpressionNominator.cs new file mode 100644 index 0000000..1a7d70d --- /dev/null +++ b/Connected.Expressions/Evaluation/ExpressionNominator.cs @@ -0,0 +1,70 @@ + +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation.Projections; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Evaluation; + +internal sealed class ExpressionNominator : DatabaseVisitor +{ + private ExpressionNominator(QueryLanguage language, ProjectionAffinity affinity) + { + Language = language; + Affinity = affinity; + + Candidates = new HashSet(); + } + + private QueryLanguage Language { get; } + private ProjectionAffinity Affinity { get; } + + private HashSet Candidates { get; set; } + private bool IsBlocked { get; set; } + + public static HashSet Nominate(QueryLanguage language, ProjectionAffinity affinity, Expression expression) + { + var nominator = new ExpressionNominator(language, affinity); + + nominator.Visit(expression); + + return nominator.Candidates; + } + + protected override Expression? Visit(Expression? expression) + { + if (expression is not null) + { + var saveIsBlocked = IsBlocked; + + IsBlocked = false; + + if (Language.MustBeColumn(expression)) + Candidates.Add(expression); + else + { + base.Visit(expression); + + if (!IsBlocked) + { + if (Language.MustBeColumn(expression) || (Affinity == ProjectionAffinity.Server && Language.CanBeColumn(expression))) + Candidates.Add(expression); + else + IsBlocked = true; + } + + IsBlocked |= saveIsBlocked; + } + } + + return expression; + } + + protected override Expression VisitProjection(ProjectionExpression expression) + { + Visit(expression.Projector); + + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Evaluation/ExpressionReplacer.cs b/Connected.Expressions/Evaluation/ExpressionReplacer.cs new file mode 100644 index 0000000..492d4aa --- /dev/null +++ b/Connected.Expressions/Evaluation/ExpressionReplacer.cs @@ -0,0 +1,36 @@ +using System.Linq.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Evaluation; +internal sealed class ExpressionReplacer : DatabaseVisitor +{ + private readonly Expression _searchFor; + private readonly Expression _replaceWith; + + private ExpressionReplacer(Expression searchFor, Expression replaceWith) + { + _searchFor = searchFor; + _replaceWith = replaceWith; + } + + public static Expression Replace(Expression expression, Expression searchFor, Expression replaceWith) + { + return new ExpressionReplacer(searchFor, replaceWith).Visit(expression); + } + + public static Expression ReplaceAll(Expression expression, Expression[] searchFor, Expression[] replaceWith) + { + for (int i = 0, n = searchFor.Length; i < n; i++) + expression = Replace(expression, searchFor[i], replaceWith[i]); + + return expression; + } + + protected override Expression? Visit(Expression? exp) + { + if (exp == _searchFor) + return _replaceWith; + + return base.Visit(exp); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Evaluation/IStorageExecutor.cs b/Connected.Expressions/Evaluation/IStorageExecutor.cs new file mode 100644 index 0000000..5d3a8e2 --- /dev/null +++ b/Connected.Expressions/Evaluation/IStorageExecutor.cs @@ -0,0 +1,9 @@ +using Connected.Entities; +using Connected.Entities.Storage; + +namespace Connected.Expressions.Evaluation; +public interface IStorageExecutor +{ + IEnumerable Execute(IStorageOperation operation) + where TResult : IEntity; +} diff --git a/Connected.Expressions/Evaluation/PartialEvaluator.cs b/Connected.Expressions/Evaluation/PartialEvaluator.cs new file mode 100644 index 0000000..4f637c5 --- /dev/null +++ b/Connected.Expressions/Evaluation/PartialEvaluator.cs @@ -0,0 +1,16 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions.Evaluation; + +internal sealed class PartialEvaluator +{ + public static Expression Eval(ExpressionCompilationContext context, Expression expression) + { + return Eval(context, expression, null); + } + + public static Expression Eval(ExpressionCompilationContext context, Expression expression, Func? fnPostEval) + { + return SubtreeEvaluator.Eval(context, ColumnNominator.Nominate(expression), fnPostEval, expression); + } +} diff --git a/Connected.Expressions/Evaluation/SubtreeEvaluator.cs b/Connected.Expressions/Evaluation/SubtreeEvaluator.cs new file mode 100644 index 0000000..a87e5c1 --- /dev/null +++ b/Connected.Expressions/Evaluation/SubtreeEvaluator.cs @@ -0,0 +1,108 @@ +using Connected.Expressions.Reflection; +using Connected.Interop; +using System.Linq.Expressions; +using System.Reflection; +using ExpressionVisitor = Connected.Expressions.Visitors.ExpressionVisitor; + +namespace Connected.Expressions.Evaluation; + +internal sealed class SubtreeEvaluator : ExpressionVisitor +{ + private SubtreeEvaluator(ExpressionCompilationContext context, HashSet candidates, Func? onEval) + { + Candidates = candidates; + OnEval = onEval; + Context = context; + } + + public ExpressionCompilationContext Context { get; } + private HashSet Candidates { get; set; } + private Func? OnEval { get; set; } + + internal static Expression Eval(ExpressionCompilationContext context, HashSet candidates, Func? onEval, Expression exp) + { + if (new SubtreeEvaluator(context, candidates, onEval).Visit(exp) is not Expression subtreeExpression) + throw new NullReferenceException(nameof(subtreeExpression)); + + return subtreeExpression; + } + + protected override Expression? Visit(Expression? exp) + { + if (exp is null) + return null; + + if (Candidates.Contains(exp)) + return Evaluate(exp); + + return base.Visit(exp); + } + + protected override Expression VisitConditional(ConditionalExpression c) + { + if (Candidates.Contains(c.Test)) + { + var test = Evaluate(c.Test); + + if (test is ConstantExpression && ((ConstantExpression)test).Type == typeof(bool)) + { + if ((bool)((ConstantExpression)test).Value) + return Visit(c.IfTrue); + else + return Visit(c.IfFalse); + } + } + + return base.VisitConditional(c); + } + + private Expression PostEval(ConstantExpression e) + { + if (OnEval is not null) + return OnEval(e); + + return e; + } + + private Expression Evaluate(Expression e) + { + var type = e.Type; + + if (e.NodeType == ExpressionType.Convert) + { + var u = (UnaryExpression)e; + + if (Nullables.GetNonNullableType(u.Operand.Type) == Nullables.GetNonNullableType(type)) + e = ((UnaryExpression)e).Operand; + } + + if (e.NodeType == ExpressionType.Constant) + { + if (e.Type == type) + return e; + else if (Nullables.GetNonNullableType(e.Type) == Nullables.GetNonNullableType(type)) + return Expression.Constant(((ConstantExpression)e).Value, type); + } + + if (e is MemberExpression me) + { + if (me.Expression is ConstantExpression ce) + { + var value = me.Member.GetValue(ce.Value); + var constant = Expression.Constant(value, type); + + Context.Parameters.Add(me.Member.Name, constant); + + return PostEval(constant); + } + } + + if (type.GetTypeInfo().IsValueType) + e = Expression.Convert(e, typeof(object)); + + var lambda = Expression.Lambda>(e); + var fn = lambda.Compile(); + + return PostEval(Expression.Constant(fn(), type)); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Evaluation/SubtreeResolver.cs b/Connected.Expressions/Evaluation/SubtreeResolver.cs new file mode 100644 index 0000000..7e44e57 --- /dev/null +++ b/Connected.Expressions/Evaluation/SubtreeResolver.cs @@ -0,0 +1,35 @@ +using System.Linq.Expressions; +using System.Reflection; +using ExpressionVisitor = Connected.Expressions.Visitors.ExpressionVisitor; + +namespace Connected.Expressions.Evaluation; + +public sealed class SubtreeResolver : ExpressionVisitor +{ + private SubtreeResolver(Type type) + { + Type = type; + } + + private Type Type { get; } + private Expression Found { get; set; } + + public static Expression? Resolve(Expression expression, Type type) + { + var finder = new SubtreeResolver(type); + + finder.Visit(expression); + + return finder.Found; + } + + protected override Expression? Visit(Expression? exp) + { + var node = base.Visit(exp); + + if (Found is null && node is not null && Type.GetTypeInfo().IsAssignableFrom(node.Type.GetTypeInfo())) + Found = node; + + return node; + } +} \ No newline at end of file diff --git a/Connected.Expressions/ExpressionCompilationContext.cs b/Connected.Expressions/ExpressionCompilationContext.cs new file mode 100644 index 0000000..d44ec1e --- /dev/null +++ b/Connected.Expressions/ExpressionCompilationContext.cs @@ -0,0 +1,19 @@ +using Connected.Expressions.Languages; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class ExpressionCompilationContext +{ + public ExpressionCompilationContext(QueryLanguage language) + { + Language = language; + + Parameters = new(); + + } + + public QueryLanguage Language { get; } + + public Dictionary Parameters { get; } +} diff --git a/Connected.Expressions/Expressions/AggregateExpression.cs b/Connected.Expressions/Expressions/AggregateExpression.cs new file mode 100644 index 0000000..dce1e45 --- /dev/null +++ b/Connected.Expressions/Expressions/AggregateExpression.cs @@ -0,0 +1,18 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class AggregateExpression : DatabaseExpression +{ + public AggregateExpression(Type type, string aggregateName, Expression argument, bool isDistinct) + : base(DatabaseExpressionType.Aggregate, type) + { + AggregateName = aggregateName; + Argument = argument; + IsDistinct = isDistinct; + } + + public string AggregateName { get; } + public Expression Argument { get; } + public bool IsDistinct { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/AggregateSubqueryExpression.cs b/Connected.Expressions/Expressions/AggregateSubqueryExpression.cs new file mode 100644 index 0000000..c045d92 --- /dev/null +++ b/Connected.Expressions/Expressions/AggregateSubqueryExpression.cs @@ -0,0 +1,19 @@ +using Connected.Expressions.Translation; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class AggregateSubqueryExpression : DatabaseExpression +{ + public AggregateSubqueryExpression(Alias groupByAlias, Expression aggregateInGroupSelect, ScalarExpression aggregateAsSubquery) + : base(DatabaseExpressionType.AggregateSubquery, aggregateAsSubquery.Type) + { + AggregateInGroupSelect = aggregateInGroupSelect; + GroupByAlias = groupByAlias; + AggregateAsSubquery = aggregateAsSubquery; + } + + public Alias GroupByAlias { get; } + public Expression AggregateInGroupSelect { get; } + public ScalarExpression AggregateAsSubquery { get; } +} diff --git a/Connected.Expressions/Expressions/AliasedExpression.cs b/Connected.Expressions/Expressions/AliasedExpression.cs new file mode 100644 index 0000000..65615ce --- /dev/null +++ b/Connected.Expressions/Expressions/AliasedExpression.cs @@ -0,0 +1,14 @@ +using Connected.Expressions.Translation; + +namespace Connected.Expressions; + +public abstract class AliasedExpression : DatabaseExpression +{ + protected AliasedExpression(DatabaseExpressionType nodeType, Type type, Alias alias) + : base(nodeType, type) + { + Alias = alias; + } + + public Alias Alias { get; } +} diff --git a/Connected.Expressions/Expressions/BatchExpression.cs b/Connected.Expressions/Expressions/BatchExpression.cs new file mode 100644 index 0000000..242e87f --- /dev/null +++ b/Connected.Expressions/Expressions/BatchExpression.cs @@ -0,0 +1,23 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class BatchExpression : Expression +{ + public BatchExpression(Expression input, LambdaExpression operation, Expression batchSize, Expression stream) + { + Input = input; + Operation = operation; + BatchSize = batchSize; + Stream = stream; + Type = typeof(IEnumerable<>).MakeGenericType(operation.Body.Type); + } + + public override Type Type { get; } + public Expression Input { get; } + public LambdaExpression Operation { get; } + public Expression BatchSize { get; } + public Expression Stream { get; } + + public override ExpressionType NodeType => (ExpressionType)DatabaseExpressionType.Batch; +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/BetweenExpression.cs b/Connected.Expressions/Expressions/BetweenExpression.cs new file mode 100644 index 0000000..5557b18 --- /dev/null +++ b/Connected.Expressions/Expressions/BetweenExpression.cs @@ -0,0 +1,18 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class BetweenExpression : DatabaseExpression +{ + public BetweenExpression(Expression expression, Expression lower, Expression upper) + : base(DatabaseExpressionType.Between, expression.Type) + { + Expression = expression; + Lower = lower; + Upper = upper; + } + + public Expression Expression { get; } + public Expression Lower { get; } + public Expression Upper { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/BlockExpression.cs b/Connected.Expressions/Expressions/BlockExpression.cs new file mode 100644 index 0000000..5de7c3a --- /dev/null +++ b/Connected.Expressions/Expressions/BlockExpression.cs @@ -0,0 +1,21 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class BlockExpression : CommandExpression +{ + public BlockExpression(IList commands) + : base(DatabaseExpressionType.Block, commands[commands.Count - 1].Type) + { + Commands = commands.ToReadOnly(); + } + + public BlockExpression(params Expression[] commands) + : this((IList)commands) + { + } + + public ReadOnlyCollection Commands { get; } +} diff --git a/Connected.Expressions/Expressions/ClientJoinExpression.cs b/Connected.Expressions/Expressions/ClientJoinExpression.cs new file mode 100644 index 0000000..a8001ee --- /dev/null +++ b/Connected.Expressions/Expressions/ClientJoinExpression.cs @@ -0,0 +1,20 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class ClientJoinExpression : DatabaseExpression +{ + public ClientJoinExpression(ProjectionExpression projection, IEnumerable outerKey, IEnumerable innerKey) + : base(DatabaseExpressionType.ClientJoin, projection.Type) + { + OuterKey = outerKey.ToReadOnly(); + InnerKey = innerKey.ToReadOnly(); + Projection = projection; + } + + public ReadOnlyCollection OuterKey { get; } + public ReadOnlyCollection InnerKey { get; } + public ProjectionExpression Projection { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/ColumnAssignment.cs b/Connected.Expressions/Expressions/ColumnAssignment.cs new file mode 100644 index 0000000..e7d402e --- /dev/null +++ b/Connected.Expressions/Expressions/ColumnAssignment.cs @@ -0,0 +1,15 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; +public sealed class ColumnAssignment +{ + public ColumnAssignment(ColumnExpression column, Expression expression) + { + Column = column; + Expression = expression; + } + + public ColumnExpression Column { get; } + public Expression Expression { get; } +} + diff --git a/Connected.Expressions/Expressions/ColumnExpression.cs b/Connected.Expressions/Expressions/ColumnExpression.cs new file mode 100644 index 0000000..a0c2d1d --- /dev/null +++ b/Connected.Expressions/Expressions/ColumnExpression.cs @@ -0,0 +1,45 @@ +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; + +namespace Connected.Expressions; + +public sealed class ColumnExpression : DatabaseExpression, IEquatable +{ + public ColumnExpression(Type type, DataType dataType, Alias alias, string name) + : base(DatabaseExpressionType.Column, type) + { + if (dataType is null) + throw new ArgumentNullException(nameof(dataType)); + + if (name is null) + throw new ArgumentNullException(nameof(name)); + + Alias = alias; + Name = name; + QueryType = dataType; + } + + public Alias Alias { get; } + public string Name { get; } + public DataType QueryType { get; } + + public override string ToString() + { + return $"{Alias}.C({Name})"; + } + + public override int GetHashCode() + { + return Alias.GetHashCode() + Name.GetHashCode(); + } + + public override bool Equals(object? obj) + { + return Equals(obj as ColumnExpression); + } + + public bool Equals(ColumnExpression? other) + { + return other is not null && (this) == other || (Alias == other?.Alias && Name == other.Name); + } +} diff --git a/Connected.Expressions/Expressions/CommandExpression.cs b/Connected.Expressions/Expressions/CommandExpression.cs new file mode 100644 index 0000000..f46dcd4 --- /dev/null +++ b/Connected.Expressions/Expressions/CommandExpression.cs @@ -0,0 +1,9 @@ +namespace Connected.Expressions; + +public abstract class CommandExpression : DatabaseExpression +{ + protected CommandExpression(DatabaseExpressionType eType, Type type) + : base(eType, type) + { + } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/DatabaseExpression.cs b/Connected.Expressions/Expressions/DatabaseExpression.cs new file mode 100644 index 0000000..4df86a2 --- /dev/null +++ b/Connected.Expressions/Expressions/DatabaseExpression.cs @@ -0,0 +1,56 @@ +using Connected.Expressions.Serialization; +using System.Diagnostics; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public enum DatabaseExpressionType +{ + Table = 1000, + ClientJoin = 1001, + Column = 1002, + Select = 1003, + Projection = 1004, + Entity = 1005, + Join = 1006, + Aggregate = 1007, + Scalar = 1008, + Exists = 1009, + In = 1010, + Grouping = 1011, + AggregateSubquery = 1012, + IsNull = 1013, + Between = 1014, + RowCount = 1015, + NamedValue = 1016, + OuterJoined = 1017, + Batch = 1018, + Function = 1019, + Block = 1020, + If = 1021, + Declaration = 1022, + Variable = 1023 +} + +[DebuggerDisplay("{DebugText}")] +public abstract class DatabaseExpression : Expression +{ + private readonly Type _type; + + protected DatabaseExpression(DatabaseExpressionType expressionType, Type type) + { + ExpressionType = expressionType; + _type = type; + } + public DatabaseExpressionType ExpressionType { get; } + + public override ExpressionType NodeType => (ExpressionType)(int)ExpressionType; + public override Type Type => _type; + + private string DebugText => $"{GetType().Name}: {DatabaseExpressionExtensions.ResolveNodeTypeName(this)} := {this}"; + + public override string ToString() + { + return DatabaseSerializer.Serialize(this); + } +} diff --git a/Connected.Expressions/Expressions/DatabaseExpressionExtensions.cs b/Connected.Expressions/Expressions/DatabaseExpressionExtensions.cs new file mode 100644 index 0000000..70de9a4 --- /dev/null +++ b/Connected.Expressions/Expressions/DatabaseExpressionExtensions.cs @@ -0,0 +1,209 @@ +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +internal static class DatabaseExpressionExtensions +{ + public static bool IsDatabaseExpression(this ExpressionType expressionType) + { + return ((int)expressionType) >= 1000; + } + + public static string ResolveNodeTypeName(this Expression expression) + { + if (expression is DatabaseExpression d) + return d.ExpressionType.ToString(); + else + return expression.NodeType.ToString(); + } + + public static SelectExpression SetColumns(this SelectExpression select, IEnumerable columns) + { + return new SelectExpression(select.Alias, columns.OrderBy(c => c.Name), select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, + select.Skip, select.Take, select.IsReverse); + } + + public static SelectExpression AddColumn(this SelectExpression select, ColumnDeclaration column) + { + var columns = new List(select.Columns) + { + column + }; + + return select.SetColumns(columns); + } + + public static SelectExpression RemoveColumn(this SelectExpression select, ColumnDeclaration column) + { + var columns = new List(select.Columns); + + columns.Remove(column); + + return select.SetColumns(columns); + } + + public static string ResolveAvailableColumnName(this IList columns, string baseName) + { + var name = baseName; + var n = 0; + + while (!IsUniqueName(columns, name)) + name = baseName + (n++); + + return name; + } + + private static bool IsUniqueName(IList columns, string name) + { + foreach (var col in columns) + { + if (string.Equals(col.Name, name, StringComparison.OrdinalIgnoreCase)) + return false; + } + + return true; + } + + public static ProjectionExpression AddOuterJoinTest(this ProjectionExpression proj, QueryLanguage language, Expression expression) + { + var colName = ResolveAvailableColumnName(proj.Select.Columns, "Test"); + var colType = language.TypeSystem.ResolveColumnType(expression.Type); + var newSource = proj.Select.AddColumn(new ColumnDeclaration(colName, expression, colType)); + var newProjector = new OuterJoinedExpression(new ColumnExpression(expression.Type, colType, newSource.Alias, colName), proj.Projector); + + return new ProjectionExpression(newSource, newProjector, proj.Aggregator); + } + + public static SelectExpression SetDistinct(this SelectExpression select, bool isDistinct) + { + if (select.IsDistinct != isDistinct) + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, select.OrderBy, select.GroupBy, isDistinct, select.Skip, select.Take, select.IsReverse); + + return select; + } + + public static SelectExpression SetReverse(this SelectExpression select, bool isReverse) + { + if (select.IsReverse != isReverse) + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, isReverse); + + return select; + } + + public static SelectExpression SetWhere(this SelectExpression select, Expression? where) + { + if (where != select.Where) + return new SelectExpression(select.Alias, select.Columns, select.From, where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); + + return select; + } + + public static SelectExpression SetOrderBy(this SelectExpression select, IEnumerable orderBy) + { + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, orderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); + } + + public static SelectExpression AddOrderExpression(this SelectExpression select, OrderExpression ordering) + { + var orderby = new List(); + + if (select.OrderBy != null) + orderby.AddRange(select.OrderBy); + + orderby.Add(ordering); + + return select.SetOrderBy(orderby); + } + + public static SelectExpression RemoveOrderExpression(this SelectExpression select, OrderExpression ordering) + { + if (select.OrderBy != null && select.OrderBy.Count > 0) + { + var orderby = new List(select.OrderBy); + + orderby.Remove(ordering); + + return select.SetOrderBy(orderby); + } + + return select; + } + + public static SelectExpression SetGroupBy(this SelectExpression select, IEnumerable groupBy) + { + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, select.OrderBy, groupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); + } + + public static SelectExpression AddGroupExpression(this SelectExpression select, Expression expression) + { + var groupby = new List(); + + if (select.GroupBy is not null) + groupby.AddRange(select.GroupBy); + + groupby.Add(expression); + + return select.SetGroupBy(groupby); + } + + public static SelectExpression RemoveGroupExpression(this SelectExpression select, Expression expression) + { + if (select.GroupBy is not null && select.GroupBy.Any()) + { + var groupby = new List(select.GroupBy); + + groupby.Remove(expression); + + return select.SetGroupBy(groupby); + } + + return select; + } + + public static SelectExpression SetSkip(this SelectExpression select, Expression? skip) + { + if (skip != select.Skip) + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, skip, select.Take, select.IsReverse); + + return select; + } + + public static SelectExpression SetTake(this SelectExpression select, Expression? take) + { + if (take != select.Take) + return new SelectExpression(select.Alias, select.Columns, select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, take, select.IsReverse); + + return select; + } + + public static SelectExpression AddRedundantSelect(this SelectExpression sel, QueryLanguage language, Alias newAlias) + { + var newColumns = from d in sel.Columns + let qt = (d.Expression is ColumnExpression) ? ((ColumnExpression)d.Expression).QueryType : language.TypeSystem.ResolveColumnType(d.Expression.Type) + select new ColumnDeclaration(d.Name, new ColumnExpression(d.Expression.Type, qt, newAlias, d.Name), qt); + + var newFrom = new SelectExpression(newAlias, sel.Columns, sel.From, sel.Where, sel.OrderBy, sel.GroupBy, sel.IsDistinct, sel.Skip, sel.Take, sel.IsReverse); + + return new SelectExpression(sel.Alias, newColumns, newFrom, null, null, null, false, null, null, false); + } + + public static SelectExpression RemoveRedundantFrom(this SelectExpression select) + { + var fromSelect = select.From as SelectExpression; + + if (fromSelect is not null) + return Subqueries.Remove(select, fromSelect); + + return select; + } + + public static SelectExpression SetFrom(this SelectExpression select, Expression from) + { + if (select.From != from) + return new SelectExpression(select.Alias, select.Columns, from, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); + + return select; + } +} diff --git a/Connected.Expressions/Expressions/DeclarationExpression.cs b/Connected.Expressions/Expressions/DeclarationExpression.cs new file mode 100644 index 0000000..7b23086 --- /dev/null +++ b/Connected.Expressions/Expressions/DeclarationExpression.cs @@ -0,0 +1,17 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; + +namespace Connected.Expressions; + +public sealed class DeclarationExpression : CommandExpression +{ + public DeclarationExpression(IEnumerable variables, SelectExpression source) + : base(DatabaseExpressionType.Declaration, typeof(void)) + { + Variables = variables.ToReadOnly(); + Source = source; + } + + public ReadOnlyCollection Variables { get; } + public SelectExpression Source { get; } +} diff --git a/Connected.Expressions/Expressions/EntityExpression.cs b/Connected.Expressions/Expressions/EntityExpression.cs new file mode 100644 index 0000000..b69216e --- /dev/null +++ b/Connected.Expressions/Expressions/EntityExpression.cs @@ -0,0 +1,16 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class EntityExpression : DatabaseExpression +{ + public EntityExpression(Type entityType, Expression expression) + : base(DatabaseExpressionType.Entity, expression.Type) + { + EntityType = entityType; + Expression = expression; + } + + public Type EntityType { get; } + public Expression Expression { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/ExistsExpression.cs b/Connected.Expressions/Expressions/ExistsExpression.cs new file mode 100644 index 0000000..b30fa7f --- /dev/null +++ b/Connected.Expressions/Expressions/ExistsExpression.cs @@ -0,0 +1,9 @@ +namespace Connected.Expressions; + +public sealed class ExistsExpression : SubqueryExpression +{ + public ExistsExpression(SelectExpression select) + : base(DatabaseExpressionType.Exists, typeof(bool), select) + { + } +} diff --git a/Connected.Expressions/Expressions/ExpressionExtensions.cs b/Connected.Expressions/Expressions/ExpressionExtensions.cs new file mode 100644 index 0000000..ab17d63 --- /dev/null +++ b/Connected.Expressions/Expressions/ExpressionExtensions.cs @@ -0,0 +1,129 @@ +using Connected.Interop; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +internal static class ExpressionExtensions +{ + public static Expression Equal(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.Equal(left, right); + } + + public static Expression NotEqual(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.NotEqual(left, right); + } + + public static Expression GreaterThan(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.GreaterThan(left, right); + } + + public static Expression GreaterThanOrEqual(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.GreaterThanOrEqual(left, right); + } + + public static Expression LessThan(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.LessThan(left, right); + } + + public static Expression LessThanOrEqual(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.LessThanOrEqual(left, right); + } + + public static Expression And(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.And(left, right); + } + + public static Expression Or(this Expression left, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.Or(left, right); + } + + public static Expression Binary(this Expression left, ExpressionType op, Expression right) + { + ConvertExpressions(ref left, ref right); + + return Expression.MakeBinary(op, left, right); + } + + private static void ConvertExpressions(ref Expression left, ref Expression right) + { + if (left.Type != right.Type) + { + var isNullable1 = Nullables.IsNullableType(left.Type); + var isNullable2 = Nullables.IsNullableType(right.Type); + + if (isNullable1 || isNullable2) + { + if (Nullables.GetNonNullableType(left.Type) == Nullables.GetNonNullableType(right.Type)) + { + if (!isNullable1) + left = Expression.Convert(left, right.Type); + else if (!isNullable2) + right = Expression.Convert(right, left.Type); + } + } + } + } + + public static Expression[] Split(this Expression expression, params ExpressionType[] binarySeparators) + { + var list = new List(); + + Split(expression, list, binarySeparators); + + return list.ToArray(); + } + + private static void Split(Expression expression, List list, ExpressionType[] binarySeparators) + { + if (expression is not null) + { + if (binarySeparators.Contains(expression.NodeType)) + { + if (expression is BinaryExpression bex) + { + Split(bex.Left, list, binarySeparators); + Split(bex.Right, list, binarySeparators); + } + } + else + list.Add(expression); + } + } + + public static Expression? Join(this IEnumerable list, ExpressionType binarySeparator) + { + if (list is not null) + { + var array = list.ToArray(); + + if (array.Any()) + return array.Aggregate((x1, x2) => Expression.MakeBinary(binarySeparator, x1, x2)); + } + + return null; + } +} diff --git a/Connected.Expressions/Expressions/FunctionExpression.cs b/Connected.Expressions/Expressions/FunctionExpression.cs new file mode 100644 index 0000000..c1c65a2 --- /dev/null +++ b/Connected.Expressions/Expressions/FunctionExpression.cs @@ -0,0 +1,20 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class FunctionExpression : DatabaseExpression +{ + public FunctionExpression(Type type, string name, IEnumerable? arguments) + : base(DatabaseExpressionType.Function, type) + { + Name = name; + + if (arguments is not null) + Arguments = arguments.ToReadOnly(); + } + + public string Name { get; } + public ReadOnlyCollection? Arguments { get; } +} diff --git a/Connected.Expressions/Expressions/IfCommandExpression.cs b/Connected.Expressions/Expressions/IfCommandExpression.cs new file mode 100644 index 0000000..52e49e8 --- /dev/null +++ b/Connected.Expressions/Expressions/IfCommandExpression.cs @@ -0,0 +1,17 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; +public sealed class IfCommandExpression : CommandExpression +{ + public IfCommandExpression(Expression check, Expression ifTrue, Expression ifFalse) + : base(DatabaseExpressionType.If, ifTrue.Type) + { + Check = check; + IfTrue = ifTrue; + IfFalse = ifFalse; + } + + public Expression Check { get; } + public Expression IfTrue { get; } + public Expression IfFalse { get; } +} diff --git a/Connected.Expressions/Expressions/InExpression.cs b/Connected.Expressions/Expressions/InExpression.cs new file mode 100644 index 0000000..f6a9a41 --- /dev/null +++ b/Connected.Expressions/Expressions/InExpression.cs @@ -0,0 +1,24 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class InExpression : SubqueryExpression +{ + public InExpression(Expression expression, SelectExpression select) + : base(DatabaseExpressionType.In, typeof(bool), select) + { + Expression = expression; + } + + public InExpression(Expression expression, IEnumerable values) + : base(DatabaseExpressionType.In, typeof(bool), null) + { + Expression = expression; + Values = values.ToReadOnly(); + } + + public Expression Expression { get; } + public ReadOnlyCollection? Values { get; } +} diff --git a/Connected.Expressions/Expressions/IsNullExpression.cs b/Connected.Expressions/Expressions/IsNullExpression.cs new file mode 100644 index 0000000..9ce209b --- /dev/null +++ b/Connected.Expressions/Expressions/IsNullExpression.cs @@ -0,0 +1,14 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class IsNullExpression : DatabaseExpression +{ + public IsNullExpression(Expression expression) + : base(DatabaseExpressionType.IsNull, typeof(bool)) + { + Expression = expression; + } + + public Expression Expression { get; } +} diff --git a/Connected.Expressions/Expressions/JoinExpression.cs b/Connected.Expressions/Expressions/JoinExpression.cs new file mode 100644 index 0000000..3765b97 --- /dev/null +++ b/Connected.Expressions/Expressions/JoinExpression.cs @@ -0,0 +1,30 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public enum JoinType +{ + CrossJoin = 0, + InnerJoin = 1, + CrossApply = 2, + OuterApply = 3, + LeftOuter = 4, + SingletonLeftOuter = 5 +} + +public sealed class JoinExpression : DatabaseExpression +{ + public JoinExpression(JoinType joinType, Expression left, Expression right, Expression? condition) + : base(DatabaseExpressionType.Join, typeof(void)) + { + Join = joinType; + Left = left; + Right = right; + Condition = condition; + } + + public JoinType Join { get; } + public Expression Left { get; } + public Expression Right { get; } + public new Expression? Condition { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/NamedValueExpression.cs b/Connected.Expressions/Expressions/NamedValueExpression.cs new file mode 100644 index 0000000..41bbb1f --- /dev/null +++ b/Connected.Expressions/Expressions/NamedValueExpression.cs @@ -0,0 +1,25 @@ +using Connected.Expressions.Languages; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class NamedValueExpression : DatabaseExpression +{ + public NamedValueExpression(string name, DataType dataType, Expression value) + : base(DatabaseExpressionType.NamedValue, value.Type) + { + if (name is null) + throw new ArgumentNullException(nameof(name)); + + if (value is null) + throw new ArgumentNullException(nameof(value)); + + Name = name; + DataType = dataType; + Value = value; + } + + public string Name { get; } + public DataType DataType { get; } + public Expression Value { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/OrderExpression.cs b/Connected.Expressions/Expressions/OrderExpression.cs new file mode 100644 index 0000000..8ecf8b6 --- /dev/null +++ b/Connected.Expressions/Expressions/OrderExpression.cs @@ -0,0 +1,21 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public enum OrderType +{ + Ascending, + Descending +} + +public sealed class OrderExpression +{ + public OrderExpression(OrderType orderType, Expression expression) + { + OrderType = orderType; + Expression = expression; + } + + public OrderType OrderType { get; } + public Expression Expression { get; } +} diff --git a/Connected.Expressions/Expressions/OuterJoinedExpression.cs b/Connected.Expressions/Expressions/OuterJoinedExpression.cs new file mode 100644 index 0000000..c3b388f --- /dev/null +++ b/Connected.Expressions/Expressions/OuterJoinedExpression.cs @@ -0,0 +1,16 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class OuterJoinedExpression : DatabaseExpression +{ + public OuterJoinedExpression(Expression test, Expression expression) + : base(DatabaseExpressionType.OuterJoined, expression.Type) + { + Test = test; + Expression = expression; + } + + public Expression Test { get; } + public Expression Expression { get; } +} diff --git a/Connected.Expressions/Expressions/ProjectionExpression.cs b/Connected.Expressions/Expressions/ProjectionExpression.cs new file mode 100644 index 0000000..dacd5f2 --- /dev/null +++ b/Connected.Expressions/Expressions/ProjectionExpression.cs @@ -0,0 +1,32 @@ +using Connected.Expressions.Formatters; +using Connected.Expressions.Serialization; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class ProjectionExpression : DatabaseExpression +{ + public ProjectionExpression(SelectExpression source, Expression projector) + : this(source, projector, null) + { + } + + public ProjectionExpression(SelectExpression source, Expression projector, LambdaExpression? aggregator) + : base(DatabaseExpressionType.Projection, aggregator is not null ? aggregator.Body.Type : typeof(IEnumerable<>).MakeGenericType(projector.Type)) + { + Select = source; + Projector = projector; + Aggregator = aggregator; + } + + public SelectExpression Select { get; } + public Expression Projector { get; } + public LambdaExpression? Aggregator { get; } + public bool IsSingleton => Aggregator?.Body.Type == Projector.Type; + public string QueryText => SqlFormatter.Format(Select); + + public override string ToString() + { + return DatabaseSerializer.Serialize(this); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/RowNumberExpression.cs b/Connected.Expressions/Expressions/RowNumberExpression.cs new file mode 100644 index 0000000..ab951ef --- /dev/null +++ b/Connected.Expressions/Expressions/RowNumberExpression.cs @@ -0,0 +1,15 @@ +using Connected.Expressions.Collections; +using System.Collections.ObjectModel; + +namespace Connected.Expressions; + +public sealed class RowNumberExpression : DatabaseExpression +{ + public RowNumberExpression(IEnumerable? orderBy) + : base(DatabaseExpressionType.RowCount, typeof(int)) + { + OrderBy = orderBy is null ? new List().AsReadOnly() : orderBy.ToReadOnly(); + } + + public ReadOnlyCollection OrderBy { get; } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/ScalarExpression.cs b/Connected.Expressions/Expressions/ScalarExpression.cs new file mode 100644 index 0000000..ca0a3d7 --- /dev/null +++ b/Connected.Expressions/Expressions/ScalarExpression.cs @@ -0,0 +1,9 @@ +namespace Connected.Expressions; + +public sealed class ScalarExpression : SubqueryExpression +{ + public ScalarExpression(Type type, SelectExpression select) + : base(DatabaseExpressionType.Scalar, type, select) + { + } +} diff --git a/Connected.Expressions/Expressions/SelectExpression.cs b/Connected.Expressions/Expressions/SelectExpression.cs new file mode 100644 index 0000000..9ef1e30 --- /dev/null +++ b/Connected.Expressions/Expressions/SelectExpression.cs @@ -0,0 +1,46 @@ +using Connected.Expressions.Collections; +using Connected.Expressions.Formatters; +using Connected.Expressions.Translation; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class SelectExpression : AliasedExpression +{ + public SelectExpression(Alias alias, IEnumerable columns, Expression from, Expression? where, + IEnumerable? orderBy, IEnumerable? groupBy, bool isDistinct, Expression? skip, Expression? take, bool isReverse) + : base(DatabaseExpressionType.Select, typeof(void), alias) + { + Columns = columns.ToReadOnly(); + IsDistinct = isDistinct; + From = from; + Where = where; + OrderBy = orderBy?.ToReadOnly(); + GroupBy = groupBy?.ToReadOnly(); + Take = take; + Skip = skip; + IsReverse = isReverse; + } + + public SelectExpression(Alias alias, IEnumerable columns, Expression from, Expression? where, IEnumerable? orderBy, IEnumerable? groupBy) + : this(alias, columns, from, where, orderBy, groupBy, false, null, null, false) + { + } + + public SelectExpression(Alias alias, IEnumerable columns, Expression from, Expression? where) + : this(alias, columns, from, where, null, null) + { + } + + public ReadOnlyCollection Columns { get; } + public Expression From { get; } + public Expression? Where { get; } + public ReadOnlyCollection? OrderBy { get; } + public ReadOnlyCollection? GroupBy { get; } + public bool IsDistinct { get; } + public Expression? Skip { get; } + public Expression? Take { get; } + public bool IsReverse { get; } + public string QueryText => SqlFormatter.Format(this); +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/SubqueryExpression.cs b/Connected.Expressions/Expressions/SubqueryExpression.cs new file mode 100644 index 0000000..18f7d79 --- /dev/null +++ b/Connected.Expressions/Expressions/SubqueryExpression.cs @@ -0,0 +1,14 @@ +namespace Connected.Expressions; + +public class SubqueryExpression : DatabaseExpression +{ + protected SubqueryExpression(DatabaseExpressionType eType, Type type, SelectExpression? select) + : base(eType, type) + { + System.Diagnostics.Debug.Assert(eType == DatabaseExpressionType.Scalar || eType == DatabaseExpressionType.Exists || eType == DatabaseExpressionType.In); + + Select = select; + } + + public SelectExpression? Select { get; } +} diff --git a/Connected.Expressions/Expressions/TableExpression.cs b/Connected.Expressions/Expressions/TableExpression.cs new file mode 100644 index 0000000..05c9277 --- /dev/null +++ b/Connected.Expressions/Expressions/TableExpression.cs @@ -0,0 +1,23 @@ +using Connected.Expressions.Translation; + +namespace Connected.Expressions; + +public sealed class TableExpression : AliasedExpression +{ + public TableExpression(Alias alias, Type entity, string schema, string name) + : base(DatabaseExpressionType.Table, typeof(void), alias) + { + Entity = entity; + Name = name; + Schema = schema; + } + + public Type Entity { get; } + public string Name { get; } + public string Schema { get; } + + public override string ToString() + { + return $"T({Schema}.{Name})"; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Expressions/VariableDeclaration.cs b/Connected.Expressions/Expressions/VariableDeclaration.cs new file mode 100644 index 0000000..945ca10 --- /dev/null +++ b/Connected.Expressions/Expressions/VariableDeclaration.cs @@ -0,0 +1,18 @@ +using Connected.Expressions.Languages; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public class VariableDeclaration +{ + public VariableDeclaration(string name, DataType dataType, Expression expression) + { + Name = name; + DataType = dataType; + Expression = expression; + } + + public string Name { get; } + public DataType DataType { get; } + public Expression Expression { get; } +} diff --git a/Connected.Expressions/Expressions/VariableExpression.cs b/Connected.Expressions/Expressions/VariableExpression.cs new file mode 100644 index 0000000..ea9f8fd --- /dev/null +++ b/Connected.Expressions/Expressions/VariableExpression.cs @@ -0,0 +1,19 @@ +using Connected.Expressions.Languages; +using System.Linq.Expressions; + +namespace Connected.Expressions; + +public sealed class VariableExpression : Expression +{ + public VariableExpression(string name, Type type, DataType dataType) + { + Name = name; + Type = type; + DataType = dataType; + } + + public string Name { get; } + public override Type Type { get; } + public DataType DataType { get; } + public override ExpressionType NodeType => (ExpressionType)DatabaseExpressionType.Variable; +} \ No newline at end of file diff --git a/Connected.Expressions/Formatters/SqlFormatter.cs b/Connected.Expressions/Formatters/SqlFormatter.cs new file mode 100644 index 0000000..95b4d2c --- /dev/null +++ b/Connected.Expressions/Formatters/SqlFormatter.cs @@ -0,0 +1,984 @@ +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; +using Connected.Expressions.Visitors; +using System.Collections.ObjectModel; +using System.Globalization; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; + +namespace Connected.Expressions.Formatters; + +public class SqlFormatter : DatabaseVisitor +{ + protected const char Space = ' '; + protected const char Period = '.'; + protected const char OpenBracket = '('; + protected const char CloseBracket = ')'; + protected const char SingleQuote = '\''; + protected enum Indentation + { + Same, + Inner, + Outer + } + + protected SqlFormatter(QueryLanguage? language) + { + Language = language; + Text = new StringBuilder(); + Aliases = new(); + } + + private int Depth { get; set; } + protected virtual QueryLanguage? Language { get; } + protected bool HideColumnAliases { get; set; } + protected bool HideTableAliases { get; set; } + protected bool IsNested { get; set; } + public int IndentationWidth { get; set; } = 2; + private StringBuilder Text { get; } + private Dictionary Aliases { get; } + + public static string Format(Expression expression) + { + var formatter = new SqlFormatter(null); + + formatter.Visit(expression); + + return formatter.ToString(); + } + + public override string ToString() + { + return Text.ToString(); + } + + protected void Write(object value) + { + Text.Append(value); + } + + protected virtual void WriteParameterName(string name) + { + Write($"@{name}"); + } + + protected virtual void WriteVariableName(string name) + { + WriteParameterName(name); + } + + protected virtual void WriteAsAliasName(string aliasName) + { + Write("AS "); + WriteAliasName(aliasName); + } + + protected virtual void WriteAliasName(string aliasName) + { + Write(aliasName); + } + + protected virtual void WriteAsColumnName(string columnName) + { + Write("AS "); + WriteColumnName(columnName); + } + + protected virtual void WriteColumnName(string columnName) + { + var name = Language is not null ? Language.Quote(columnName) : columnName; + + Write(name); + } + + protected virtual void WriteTableName(string tableSchema, string tableName) + { + var name = Language is not null ? Language.Quote(tableName) : tableName; + var schema = Language is not null ? Language.Quote(tableSchema) : tableName; + + Write($"{schema}.{name}"); + } + + protected void WriteLine(Indentation style) + { + Text.AppendLine(); + Indent(style); + + for (var i = 0; i < Depth * IndentationWidth; i++) + Write(Space); + } + + protected void Indent(Indentation style) + { + if (style == Indentation.Inner) + Depth++; + else if (style == Indentation.Outer) + Depth--; + } + + protected virtual string GetAliasName(Alias alias) + { + if (!Aliases.TryGetValue(alias, out string? name)) + { + name = $"A{alias.GetHashCode()}?"; + + Aliases.Add(alias, name); + } + + return name; + } + + protected void AddAlias(Alias alias) + { + if (!Aliases.TryGetValue(alias, out _)) + { + var name = $"t{Aliases.Count}"; + + Aliases.Add(alias, name); + } + } + + protected virtual void AddAliases(Expression expr) + { + if (expr as AliasedExpression is AliasedExpression ax) + AddAlias(ax.Alias); + else + { + if (expr as JoinExpression is JoinExpression jx) + { + AddAliases(jx.Left); + AddAliases(jx.Right); + } + } + } + + protected override Expression? Visit(Expression? exp) + { + if (exp is null) + return null; + + return exp.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Not or ExpressionType.Convert or ExpressionType.ConvertChecked + or ExpressionType.UnaryPlus or ExpressionType.Add or ExpressionType.AddChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked + or ExpressionType.Multiply or ExpressionType.MultiplyChecked or ExpressionType.Divide or ExpressionType.Modulo or ExpressionType.And + or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.Coalesce or ExpressionType.RightShift or ExpressionType.LeftShift or ExpressionType.ExclusiveOr + or ExpressionType.Power or ExpressionType.Conditional or ExpressionType.Constant or ExpressionType.MemberAccess or ExpressionType.Call + or ExpressionType.New or (ExpressionType)DatabaseExpressionType.Table or (ExpressionType)DatabaseExpressionType.Column + or (ExpressionType)DatabaseExpressionType.Select or (ExpressionType)DatabaseExpressionType.Join or (ExpressionType)DatabaseExpressionType.Aggregate + or (ExpressionType)DatabaseExpressionType.Scalar or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In + or (ExpressionType)DatabaseExpressionType.AggregateSubquery or (ExpressionType)DatabaseExpressionType.IsNull + or (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.RowCount + or (ExpressionType)DatabaseExpressionType.Projection or (ExpressionType)DatabaseExpressionType.NamedValue + or (ExpressionType)DatabaseExpressionType.Block or (ExpressionType)DatabaseExpressionType.If or (ExpressionType)DatabaseExpressionType.Declaration + or (ExpressionType)DatabaseExpressionType.Variable or (ExpressionType)DatabaseExpressionType.Function => base.Visit(exp), + _ => throw new NotSupportedException($"The expression node of type '{DatabaseExpressionExtensions.ResolveNodeTypeName(exp)}' is not supported."), + }; + } + + protected override Expression VisitMemberAccess(MemberExpression m) + { + throw new NotSupportedException($"The member access '{m.Member}' is not supported."); + } + + protected override Expression? VisitMethodCall(MethodCallExpression m) + { + if (m.Method.DeclaringType == typeof(decimal)) + { + switch (m.Method.Name) + { + case "Add": + case "Subtract": + case "Multiply": + case "Divide": + case "Remainder": + Write(OpenBracket); + VisitValue(m.Arguments[0]); + Write(Space); + Write(GetOperator(m.Method.Name)); + Write(Space); + VisitValue(m.Arguments[1]); + Write(CloseBracket); + return m; + case "Negate": + Write('-'); + Visit(m.Arguments[0]); + Write(string.Empty); + return m; + case "Compare": + Visit(Expression.Condition( + Expression.Equal(m.Arguments[0], m.Arguments[1]), + Expression.Constant(0), + Expression.Condition( + Expression.LessThan(m.Arguments[0], m.Arguments[1]), + Expression.Constant(-1), + Expression.Constant(1) + ))); + + return m; + } + } + else if (string.Equals(m.Method.Name, "ToString", StringComparison.Ordinal) && m.Object?.Type == typeof(string)) + { + /* + * no op. + */ + return Visit(m.Object); + } + else if (string.Equals(m.Method.Name, "Equals", StringComparison.Ordinal)) + { + if (m.Method.IsStatic && m.Method.DeclaringType == typeof(object) || m.Method.DeclaringType == typeof(string)) + { + Write(OpenBracket); + Visit(m.Arguments[0]); + Write(" = "); + Visit(m.Arguments[1]); + Write(CloseBracket); + return m; + } + else if (!m.Method.IsStatic && m.Arguments.Count == 1 && m.Arguments[0].Type == m.Object?.Type) + { + Write(OpenBracket); + Visit(m.Object); + Write(" = "); + Visit(m.Arguments[0]); + Write(CloseBracket); + return m; + } + } + + throw new NotSupportedException($"The method '{m.Method.Name}' is not supported"); + } + + protected virtual bool IsInteger(Type type) + { + return Interop.TypeSystem.IsInteger(type); + } + + protected override NewExpression VisitNew(NewExpression nex) + { + throw new NotSupportedException($"The constructor for '{nex.Constructor?.DeclaringType}' is not supported"); + } + + protected override Expression VisitUnary(UnaryExpression u) + { + var op = GetOperator(u); + + switch (u.NodeType) + { + case ExpressionType.Not: + + if (u.Operand is IsNullExpression nullExpression) + { + Visit(nullExpression.Expression); + Write(" IS NOT NULL"); + } + else if (IsBoolean(u.Operand.Type) || op.Length > 1) + { + Write(op); + Write(Space); + VisitPredicate(u.Operand); + } + else + { + Write(op); + VisitValue(u.Operand); + } + + break; + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + Write(op); + VisitValue(u.Operand); + break; + case ExpressionType.UnaryPlus: + VisitValue(u.Operand); + break; + case ExpressionType.Convert: + /* + * ignore conversions for now + */ + Visit(u.Operand); + break; + default: + throw new NotSupportedException($"The unary operator '{u.NodeType}' is not supported"); + } + + return u; + } + + protected override Expression VisitBinary(BinaryExpression b) + { + var op = GetOperator(b); + var left = b.Left; + var right = b.Right; + + Write(OpenBracket); + + switch (b.NodeType) + { + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + if (IsBoolean(left.Type)) + { + VisitPredicate(left); + Write(Space); + Write(op); + Write(Space); + VisitPredicate(right); + } + else + { + VisitValue(left); + Write(Space); + Write(op); + Write(Space); + VisitValue(right); + } + break; + case ExpressionType.Equal: + if (right.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)right; + + if (ce.Value is null) + { + Visit(left); + Write(" IS NULL"); + + break; + } + } + else if (left.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)left; + + if (ce.Value is null) + { + Visit(right); + Write(" IS NULL"); + + break; + } + } + goto case ExpressionType.LessThan; + case ExpressionType.NotEqual: + if (right.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)right; + + if (ce.Value is null) + { + Visit(left); + Write(" IS NOT NULL"); + + break; + } + } + else if (left.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)left; + + if (ce.Value is null) + { + Visit(right); + Write(" IS NOT NULL"); + + break; + } + } + goto case ExpressionType.LessThan; + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + /* + * check for special x.CompareTo(y) && type.Compare(x,y) + */ + if (left.NodeType == ExpressionType.Call && right.NodeType == ExpressionType.Constant) + { + var mc = (MethodCallExpression)left; + var ce = (ConstantExpression)right; + + if (ce.Value is not null && ce.Value.GetType() == typeof(int) && (int)ce.Value == 0) + { + if (string.Equals(mc.Method.Name, "CompareTo", StringComparison.Ordinal) && !mc.Method.IsStatic && mc.Arguments.Count == 1) + { + left = mc.Object; + right = mc.Arguments[0]; + } + else if ((mc.Method.DeclaringType == typeof(string) || mc.Method.DeclaringType == typeof(decimal)) + && string.Equals(mc.Method.Name, "Compare", StringComparison.Ordinal) && mc.Method.IsStatic && mc.Arguments.Count == 2) + { + left = mc.Arguments[0]; + right = mc.Arguments[1]; + } + } + } + goto case ExpressionType.Add; + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + + if (left is not null) + VisitValue(left); + + Write(Space); + Write(op); + Write(Space); + VisitValue(right); + break; + default: + throw new NotSupportedException($"The binary operator '{b.NodeType}' is not supported"); + } + + Write(CloseBracket); + + return b; + } + + protected virtual string GetOperator(string methodName) + { + return methodName switch + { + "Add" => "+", + "Subtract" => "-", + "Multiply" => "*", + "Divide" => "/", + "Negate" => "-", + "Remainder" => "%", + _ => string.Empty, + }; + } + + protected virtual string GetOperator(UnaryExpression u) + { + return u.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked => "-", + ExpressionType.UnaryPlus => "+", + ExpressionType.Not => IsBoolean(u.Operand.Type) ? "NOT" : "~", + _ => string.Empty, + }; + } + + protected virtual string GetOperator(BinaryExpression b) + { + return b.NodeType switch + { + ExpressionType.And or ExpressionType.AndAlso => IsBoolean(b.Left.Type) ? "AND" : "&", + ExpressionType.Or or ExpressionType.OrElse => IsBoolean(b.Left.Type) ? "OR" : "|", + ExpressionType.Equal => "=", + ExpressionType.NotEqual => "<>", + ExpressionType.LessThan => "<", + ExpressionType.LessThanOrEqual => "<=", + ExpressionType.GreaterThan => ">", + ExpressionType.GreaterThanOrEqual => ">=", + ExpressionType.Add or ExpressionType.AddChecked => "+", + ExpressionType.Subtract or ExpressionType.SubtractChecked => "-", + ExpressionType.Multiply or ExpressionType.MultiplyChecked => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.ExclusiveOr => "^", + ExpressionType.LeftShift => "<<", + ExpressionType.RightShift => ">>", + _ => string.Empty, + }; + } + + protected virtual bool IsBoolean(Type type) + { + return type == typeof(bool) || type == typeof(bool?); + } + + protected virtual bool IsPredicate(Expression expr) + { + return expr.NodeType switch + { + ExpressionType.And or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse => IsBoolean(((BinaryExpression)expr).Type), + ExpressionType.Not => IsBoolean(((UnaryExpression)expr).Type), + ExpressionType.Equal or ExpressionType.NotEqual or ExpressionType.LessThan or ExpressionType.LessThanOrEqual or + ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or (ExpressionType)DatabaseExpressionType.IsNull or + (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In => true, + ExpressionType.Call => IsBoolean(((MethodCallExpression)expr).Type), + _ => false, + }; + } + + protected virtual Expression VisitPredicate(Expression expr) + { + Visit(expr); + + if (!IsPredicate(expr)) + Write(" <> 0"); + + return expr; + } + + protected virtual Expression? VisitValue(Expression expr) + { + return Visit(expr); + } + + protected override Expression VisitConditional(ConditionalExpression c) + { + throw new NotSupportedException("Conditional expressions not supported."); + } + + protected override Expression VisitConstant(ConstantExpression c) + { + WriteValue(c.Value); + + return c; + } + + protected virtual void WriteValue(object? value) + { + if (value is null) + Write("NULL"); + else if (value.GetType().GetTypeInfo().IsEnum) + Write(Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()))); + else + { + switch (Interop.TypeSystem.GetTypeCode(value.GetType())) + { + case TypeCode.Boolean: + Write((bool)value ? 1 : 0); + break; + case TypeCode.String: + Write(SingleQuote); + Write(value); + Write(SingleQuote); + break; + case TypeCode.Object: + throw new NotSupportedException($"The constant for '{value}' is not supported"); + case TypeCode.Single: + case TypeCode.Double: + var str = ((IConvertible)value).ToString(NumberFormatInfo.InvariantInfo); + + if (!str.Contains(Period)) + str = string.Concat(str, $"{Period}0"); + + Write(str); + break; + default: + Write((value as IConvertible)?.ToString(CultureInfo.InvariantCulture) ?? value); + break; + } + } + } + protected override Expression VisitColumn(ColumnExpression column) + { + if (column.Alias is not null && !HideColumnAliases) + { + WriteAliasName(GetAliasName(column.Alias)); + Write(Period); + } + + WriteColumnName(column.Name); + + return column; + } + protected override Expression VisitProjection(ProjectionExpression proj) + { + // treat these like scalar subqueries + if (proj.Projector is ColumnExpression) + { + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(proj.Select); + Write(CloseBracket); + Indent(Indentation.Outer); + } + else + throw new NotSupportedException("Non-scalar projections cannot be translated to SQL."); + + return proj; + } + + protected override Expression VisitSelect(SelectExpression select) + { + AddAliases(select.From); + Write("SELECT "); + + if (select.IsDistinct) + Write("DISTINCT "); + + if (select.Take is not null) + WriteTopClause(select.Take); + + WriteColumns(select.Columns); + + if (select.From is not null) + { + WriteLine(Indentation.Same); + Write("FROM "); + VisitSource(select.From); + } + + if (select.Where is not null) + { + WriteLine(Indentation.Same); + Write("WHERE "); + VisitPredicate(select.Where); + } + + if (select.GroupBy is not null && select.GroupBy.Any()) + { + WriteLine(Indentation.Same); + Write("GROUP BY "); + + for (var i = 0; i < select.GroupBy.Count; i++) + { + if (i > 0) + Write(", "); + + VisitValue(select.GroupBy[i]); + } + } + + if (select.OrderBy is not null && select.OrderBy.Any()) + { + WriteLine(Indentation.Same); + Write("ORDER BY "); + + for (var i = 0; i < select.OrderBy.Count; i++) + { + var exp = select.OrderBy[i]; + + if (i > 0) + Write(", "); + + VisitValue(exp.Expression); + + if (exp.OrderType != OrderType.Ascending) + Write(" DESC"); + } + } + + return select; + } + + protected virtual void WriteTopClause(Expression expression) + { + Write("TOP ("); + Visit(expression); + Write(") "); + } + + protected virtual void WriteColumns(ReadOnlyCollection columns) + { + if (columns.Any()) + { + for (var i = 0; i < columns.Count; i++) + { + var column = columns[i]; + + if (i > 0) + Write(", "); + + var c = VisitValue(column.Expression) as ColumnExpression; + + if (!string.IsNullOrEmpty(column.Name) && (c is null || !string.Equals(c.Name, column.Name, StringComparison.OrdinalIgnoreCase))) + { + Write(Space); + WriteAsColumnName(column.Name); + } + } + } + else + { + Write("NULL "); + + if (IsNested) + { + WriteAsColumnName("tmp"); + Write(Space); + } + } + } + + protected override Expression VisitSource(Expression source) + { + var saveIsNested = IsNested; + + IsNested = true; + + switch ((DatabaseExpressionType)source.NodeType) + { + case DatabaseExpressionType.Table: + var table = (TableExpression)source; + + WriteTableName(table.Schema, table.Name); + + if (!HideTableAliases) + { + Write(Space); + WriteAsAliasName(GetAliasName(table.Alias)); + } + break; + case DatabaseExpressionType.Select: + var select = (SelectExpression)source; + + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(select); + WriteLine(Indentation.Same); + Write($"{CloseBracket} "); + WriteAsAliasName(GetAliasName(select.Alias)); + Indent(Indentation.Outer); + break; + case DatabaseExpressionType.Join: + VisitJoin((JoinExpression)source); + break; + default: + throw new InvalidOperationException("Select source is not valid type"); + } + + IsNested = saveIsNested; + + return source; + } + + protected override Expression VisitJoin(JoinExpression join) + { + VisitJoinLeft(join.Left); + WriteLine(Indentation.Same); + + switch (join.Join) + { + case JoinType.CrossJoin: + Write("CROSS JOIN "); + break; + case JoinType.InnerJoin: + Write("INNER JOIN "); + break; + case JoinType.CrossApply: + Write("CROSS APPLY "); + break; + case JoinType.OuterApply: + Write("OUTER APPLY "); + break; + case JoinType.LeftOuter: + case JoinType.SingletonLeftOuter: + Write("LEFT OUTER JOIN "); + break; + } + + VisitJoinRight(join.Right); + + if (join.Condition is not null) + { + WriteLine(Indentation.Inner); + Write("ON "); + VisitPredicate(join.Condition); + Indent(Indentation.Outer); + } + + return join; + } + + protected virtual Expression VisitJoinLeft(Expression source) + { + return VisitSource(source); + } + + protected virtual Expression VisitJoinRight(Expression source) + { + return VisitSource(source); + } + + protected virtual void WriteAggregateName(string aggregateName) + { + switch (aggregateName) + { + case "Average": + Write("AVG"); + break; + case "LongCount": + Write("COUNT"); + break; + default: + Write(aggregateName.ToUpper()); + break; + } + } + + protected virtual bool RequiresAsteriskWhenNoArgument(string aggregateName) + { + return string.Equals(aggregateName, "Count", StringComparison.Ordinal) || string.Equals(aggregateName, "LongCount", StringComparison.Ordinal); + } + + protected override Expression VisitAggregate(AggregateExpression aggregate) + { + WriteAggregateName(aggregate.AggregateName); + Write(OpenBracket); + + if (aggregate.IsDistinct) + Write("DISTINCT "); + + if (aggregate.Argument is not null) + VisitValue(aggregate.Argument); + else if (RequiresAsteriskWhenNoArgument(aggregate.AggregateName)) + Write("*"); + + Write(CloseBracket); + + return aggregate; + } + + protected override Expression VisitIsNull(IsNullExpression isnull) + { + VisitValue(isnull.Expression); + Write(" IS NULL"); + + return isnull; + } + + protected override Expression VisitBetween(BetweenExpression between) + { + VisitValue(between.Expression); + Write(" BETWEEN "); + VisitValue(between.Lower); + Write(" AND "); + VisitValue(between.Upper); + + return between; + } + + protected override Expression VisitRowNumber(RowNumberExpression rowNumber) + { + throw new NotSupportedException(); + } + + protected override Expression VisitScalar(ScalarExpression subquery) + { + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(subquery.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + + return subquery; + } + + protected override Expression VisitExists(ExistsExpression exists) + { + Write($"EXISTS{OpenBracket}"); + WriteLine(Indentation.Inner); + Visit(exists.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + + return exists; + } + protected override Expression VisitIn(InExpression @in) + { + if (@in.Values is not null) + { + if (!@in.Values.Any()) + Write("0 <> 0"); + else + { + VisitValue(@in.Expression); + Write($" IN {OpenBracket}"); + + for (var i = 0; i < @in.Values.Count; i++) + { + if (i > 0) + Write(", "); + + VisitValue(@in.Values[i]); + } + + Write(CloseBracket); + } + } + else + { + VisitValue(@in.Expression); + Write($" IN {OpenBracket}"); + WriteLine(Indentation.Inner); + Visit(@in.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + } + + return @in; + } + + protected override Expression VisitNamedValue(NamedValueExpression value) + { + WriteParameterName(value.Name); + + return value; + } + + protected override Expression VisitIf(IfCommandExpression ifx) + { + throw new NotSupportedException(); + } + + protected override Expression VisitBlock(BlockExpression block) + { + throw new NotSupportedException(); + } + + protected override Expression VisitDeclaration(DeclarationExpression declaration) + { + throw new NotSupportedException(nameof(declaration)); + } + + protected override Expression VisitVariable(VariableExpression vex) + { + WriteVariableName(vex.Name); + + return vex; + } + + protected virtual void VisitStatement(Expression expression) + { + if (expression is ProjectionExpression p) + Visit(p.Select); + else + Visit(expression); + } + + protected override Expression VisitFunction(FunctionExpression func) + { + Write(func.Name); + + if (func.Arguments is not null && func.Arguments.Any()) + { + Write(OpenBracket); + + for (var i = 0; i < func.Arguments.Count; i++) + { + if (i > 0) + Write(", "); + + Visit(func.Arguments[i]); + } + + Write(CloseBracket); + } + + return func; + } +} \ No newline at end of file diff --git a/Connected.Expressions/IStorage.cs b/Connected.Expressions/IStorage.cs new file mode 100644 index 0000000..d3b0298 --- /dev/null +++ b/Connected.Expressions/IStorage.cs @@ -0,0 +1,4 @@ +namespace Connected.Expressions; +internal interface IStorage : IQueryable, IQueryable +{ +} diff --git a/Connected.Expressions/Languages/DataType.cs b/Connected.Expressions/Languages/DataType.cs new file mode 100644 index 0000000..10f1a84 --- /dev/null +++ b/Connected.Expressions/Languages/DataType.cs @@ -0,0 +1,9 @@ +namespace Connected.Expressions.Languages; + +public abstract class DataType +{ + public bool NotNull { get; protected set; } + public int Length { get; protected set; } + public short Precision { get; protected set; } + public short Scale { get; protected set; } +} diff --git a/Connected.Expressions/Languages/Linguist.cs b/Connected.Expressions/Languages/Linguist.cs new file mode 100644 index 0000000..d4430ff --- /dev/null +++ b/Connected.Expressions/Languages/Linguist.cs @@ -0,0 +1,76 @@ +using Connected.Expressions.Formatters; +using Connected.Expressions.Translation; +using Connected.Expressions.Translation.Optimization; +using Connected.Expressions.Translation.Rewriters; +using System.Linq.Expressions; + +namespace Connected.Expressions.Languages; + +public class Linguist +{ + /// + /// Construct a + /// + public Linguist(ExpressionCompilationContext context, QueryLanguage language, Translator translator) + { + Context = context; + Language = language; + Translator = translator; + } + + protected ExpressionCompilationContext Context { get; } + public QueryLanguage Language { get; } + public Translator Translator { get; } + /// + /// Provides language specific query translation. Use this to apply language specific rewrites or + /// to make assertions/validations about the query. + /// + public virtual Expression Translate(Expression expression) + { + /* + * remove redundant layers again before cross apply rewrite + */ + expression = UnusedColumns.Remove(expression); + expression = RedundantColumns.Remove(expression); + expression = RedundantSubqueries.Remove(expression); + /* + * convert cross-apply and outer-apply joins into inner & left-outer-joins if possible + */ + var rewritten = CrossApplyRewriter.Rewrite(this.Language, expression); + /* + * convert cross joins into inner joins + */ + rewritten = CrossJoinRewriter.Rewrite(rewritten); + + if (rewritten != expression) + { + expression = rewritten; + /* + * do final reduction + */ + expression = UnusedColumns.Remove(expression); + expression = RedundantSubqueries.Remove(expression); + expression = RedundantJoins.Remove(expression); + expression = RedundantColumns.Remove(expression); + } + + return expression; + } + /// + /// Converts the query expression into text of this query language + /// + public virtual string Format(Expression expression) + { + /* + * use common SQL formatter by default + */ + return SqlFormatter.Format(expression); + } + /// + /// Determine which sub-expressions must be parameters + /// + public virtual Expression Parameterize(Expression expression) + { + return Parameterizer.Parameterize(Language, expression); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Languages/QueryLanguage.cs b/Connected.Expressions/Languages/QueryLanguage.cs new file mode 100644 index 0000000..fcfcb3f --- /dev/null +++ b/Connected.Expressions/Languages/QueryLanguage.cs @@ -0,0 +1,190 @@ +using Connected.Expressions.Translation; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.TypeSystem; +using Connected.Interop; +using System.Collections; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Expressions.Languages; + +public abstract class QueryLanguage +{ + private const string AggregateCount = "Count"; + private const string AggregateLongCount = "LongCount"; + private const string AggregateSum = "Sum"; + private const string AggregateMin = "Min"; + private const string AggregateMax = "Max"; + private const string AggregateAverage = "Average"; + + public virtual bool AllowDistinctInAggregates => false; + public abstract QueryTypeSystem TypeSystem { get; } + public virtual bool AllowsMultipleCommands => false; + public virtual bool AllowSubqueryInSelectWithoutFrom => false; + + public virtual Expression GetRowsAffectedExpression(Expression command) + { + return new FunctionExpression(typeof(int), "@@ROWCOUNT", null); + } + + public virtual bool IsRowsAffectedExpressions(Expression expression) + { + var fex = expression as FunctionExpression; + + return fex is not null && string.Equals(fex.Name, "@@ROWCOUNT", StringComparison.OrdinalIgnoreCase); + } + + public virtual string Quote(string name) + { + return name; + } + + public virtual bool IsAggregate(MemberInfo member) + { + var method = member as MethodInfo; + + if (method is not null) + { + if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(Enumerable)) + { + switch (method.Name) + { + case AggregateCount: + case AggregateLongCount: + case AggregateSum: + case AggregateMin: + case AggregateMax: + case AggregateAverage: + return true; + } + } + } + + var property = member as PropertyInfo; + + if (property is not null && string.Equals(property.Name, AggregateCount, StringComparison.Ordinal) && typeof(IEnumerable).IsAssignableFrom(property.DeclaringType)) + return true; + + return false; + } + + public virtual bool IsAggregateArgumentPredicate(string aggregateName) + { + return string.Equals(aggregateName, AggregateCount, StringComparison.Ordinal) || string.Equals(aggregateName, AggregateLongCount, StringComparison.Ordinal); + } + + public virtual Expression GetOuterJoinTest(SelectExpression select) + { + /* + * if the column is used in the join condition (equality test) + * if it is null in the database then the join test won't match (null != null) so the row won't appear + * we can safely use this existing column as our test to determine if the outer join produced a row + * + * find a column that is used in equality test + */ + var aliases = DeclaredAliasesResolver.Resolve(select.From); + var columns = JoinColumnResolver.Resolve(aliases, select).ToList(); + + if (columns.Any()) + { + /* + * prefer one that is already in the projection list. + */ + foreach (var column in columns) + { + foreach (var col in select.Columns) + { + if (column.Equals(col.Expression)) + return column; + } + } + + return columns[0]; + } + /* + * fall back to introducing a constant + */ + return Expression.Constant(1, typeof(int?)); + } + public virtual ProjectionExpression AddOuterJoinTest(ProjectionExpression expression) + { + var test = GetOuterJoinTest(expression.Select); + var select = expression.Select; + ColumnExpression? testCol = null; + /* + * look to see if test expression exists in columns already + */ + foreach (var column in select.Columns) + { + if (test.Equals(column.Expression)) + { + var colType = TypeSystem.ResolveColumnType(test.Type); + + testCol = new ColumnExpression(test.Type, colType, select.Alias, column.Name); + + break; + } + } + + if (testCol is null) + { + /* + * add expression to projection + */ + testCol = test as ColumnExpression; + + var colName = (testCol is not null) ? testCol.Name : "Test"; + + colName = expression.Select.Columns.ResolveAvailableColumnName(colName); + + var colType = TypeSystem.ResolveColumnType(test.Type); + + select = select.AddColumn(new ColumnDeclaration(colName, test, colType)); + testCol = new ColumnExpression(test.Type, colType, select.Alias, colName); + } + + var newProjector = new OuterJoinedExpression(testCol, expression.Projector); + + return new ProjectionExpression(select, newProjector, expression.Aggregator); + } + + public virtual bool IsScalar(Type type) + { + type = Nullables.GetNonNullableType(type); + + return Interop.TypeSystem.GetTypeCode(type) switch + { + TypeCode.Empty => false, + TypeCode.Object => type == typeof(DateTimeOffset) || + type == typeof(TimeSpan) || + type == typeof(Guid) || + type == typeof(byte[]), + _ => true, + }; + } + + /// + /// Determines whether the given expression can be represented as a column in a select expressionss + /// + public virtual bool CanBeColumn(Expression expression) + { + return MustBeColumn(expression) || IsScalar(expression.Type); + } + /// + /// Determines whether the given expression must be represented as a column in a SELECT column list + /// + public virtual bool MustBeColumn(Expression expression) + { + return expression.NodeType switch + { + (ExpressionType)DatabaseExpressionType.Column or (ExpressionType)DatabaseExpressionType.Scalar or (ExpressionType)DatabaseExpressionType.Exists or + (ExpressionType)DatabaseExpressionType.AggregateSubquery or (ExpressionType)DatabaseExpressionType.Aggregate => true, + _ => false, + }; + } + + public virtual Linguist CreateLinguist(ExpressionCompilationContext context, Translator translator) + { + return new Linguist(context, this, translator); + } +} diff --git a/Connected.Expressions/Mappings/ConstructorBindResult.cs b/Connected.Expressions/Mappings/ConstructorBindResult.cs new file mode 100644 index 0000000..9147571 --- /dev/null +++ b/Connected.Expressions/Mappings/ConstructorBindResult.cs @@ -0,0 +1,16 @@ +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using Connected.Expressions.Collections; + +namespace Connected.Expressions.Mappings; +internal class ConstructorBindResult +{ + public ConstructorBindResult(NewExpression expression, IEnumerable remaining) + { + Expression = expression; + Remaining = remaining.ToReadOnly(); + } + + public NewExpression Expression { get; } + public ReadOnlyCollection Remaining { get; } +} diff --git a/Connected.Expressions/Mappings/EntityAssignment.cs b/Connected.Expressions/Mappings/EntityAssignment.cs new file mode 100644 index 0000000..d3ce80f --- /dev/null +++ b/Connected.Expressions/Mappings/EntityAssignment.cs @@ -0,0 +1,14 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions.Mappings; +internal sealed class EntityAssignment +{ + public EntityAssignment(MemberMapping mapping, Expression expression) + { + Mapping = mapping; + Expression = expression; + } + + public MemberMapping Mapping { get; } + public Expression Expression { get; } +} diff --git a/Connected.Expressions/Mappings/EntityMapping.cs b/Connected.Expressions/Mappings/EntityMapping.cs new file mode 100644 index 0000000..7b73eb5 --- /dev/null +++ b/Connected.Expressions/Mappings/EntityMapping.cs @@ -0,0 +1,193 @@ +using Connected.Collections; +using Connected.Entities.Annotations; +using Connected.Expressions.Expressions; +using Connected.Expressions.Reflection; +using Connected.Expressions.Translation; +using Connected.Expressions.Translation.Projections; +using Connected.Interop; +using System.Collections.Immutable; +using System.Linq.Expressions; +using System.Reflection; +using Binder = Connected.Expressions.Translation.Binder; + +namespace Connected.Expressions.Mappings; +internal sealed class EntityMapping +{ + private List _members; + public EntityMapping(Type entityType) + { + EntityType = entityType; + _members = new(); + + InitializeSchema(); + InitializeMembers(); + } + + public string Id => $"{Schema}.{Name}"; + public string Name { get; private set; } = default!; + public string Schema { get; private set; } = default!; + private Type EntityType { get; } + public ImmutableList Members => _members.ToImmutableList(); + + private void InitializeSchema() + { + var att = EntityType.ResolveTableAttribute(); + + if (string.IsNullOrWhiteSpace(att.Name)) + Name = EntityType.Name; + else + Name = att.Name; + + if (string.IsNullOrWhiteSpace(att.Schema)) + Schema = SchemaAttribute.DefaultSchema; + else + Schema = att.Schema; + } + + private void InitializeMembers() + { + var properties = Properties.GetImplementedProperties(EntityType); + + foreach (var property in properties) + { + var member = new MemberMapping(property); + + if (member.IsValid) + _members.Add(member); + } + + _members.SortByOrdinal(); + } + + public Expression CreateExpression(ExpressionCompilationContext context) + { + var tableAlias = Alias.New(); + var selectAlias = Alias.New(); + var table = new TableExpression(tableAlias, EntityType, Schema, Name); + var projector = CreateEntityExpression(context, table); + var pc = ColumnProjector.ProjectColumns(context.Language, projector, null, selectAlias, tableAlias); + + return new ProjectionExpression(new SelectExpression(selectAlias, pc.Columns, table, null), pc.Projector); + } + + private EntityExpression CreateEntityExpression(ExpressionCompilationContext context, Expression root) + { + var assignments = new List(); + + foreach (var member in Members) + { + if (CreateMemberExpression(context, root, member) is Expression memberExpression) + assignments.Add(new EntityAssignment(member, memberExpression)); + } + + return new EntityExpression(EntityType, CreateEntityExpression(assignments)); + } + + private Expression CreateMemberExpression(ExpressionCompilationContext context, Expression root, MemberMapping member) + { + if (root is AliasedExpression aliasedRoot) + { + return new ColumnExpression(Interop.Members.GetMemberType(member.MemberInfo), context.Language.TypeSystem.ResolveColumnType(member.Type), + aliasedRoot.Alias, member.Name); + } + + return Binder.Bind(root, member.MemberInfo); + } + + private Expression CreateEntityExpression(IList assignments) + { + NewExpression newExpression; + var readonlyMembers = assignments.Where(f => f.Mapping.IsReadOnly).ToArray(); + var cons = EntityType.GetTypeInfo().DeclaredConstructors.Where(c => c.IsPublic && !c.IsStatic).ToArray(); + var hasNoArgConstructor = cons.Any(c => c.GetParameters().Length == 0); + + if (readonlyMembers.Any() || !hasNoArgConstructor) + { + var consThatApply = cons.Select(c => BindConstructor(c, readonlyMembers)).Where(cbr => cbr is not null && !cbr.Remaining.Any()).ToList(); + + if (!consThatApply.Any()) + throw new InvalidOperationException($"Cannot construct type '{EntityType}' with all mapped and included members."); + + if (readonlyMembers.Length == assignments.Count) + return consThatApply[0].Expression; + + var r = BindConstructor(consThatApply[0].Expression.Constructor, assignments); + + newExpression = r.Expression; + assignments = r.Remaining; + } + else + newExpression = Expression.New(EntityType); + + Expression result; + + if (assignments.Any()) + { + if (EntityType.GetTypeInfo().IsInterface) + assignments = RemapAssignments(assignments, EntityType).ToList(); + + result = Expression.MemberInit(newExpression, assignments.Select(a => Expression.Bind(a.Mapping.MemberInfo, a.Expression)).ToArray()); + } + else + result = newExpression; + + return result; + } + + private ConstructorBindResult BindConstructor(ConstructorInfo cons, IList assignments) + { + var ps = cons.GetParameters(); + var args = new Expression[ps.Length]; + var mis = new MemberInfo[ps.Length]; + var members = new HashSet(assignments); + var used = new HashSet(); + + for (var i = 0; i < ps.Length; i++) + { + var p = ps[i]; + var assignment = members.FirstOrDefault(a => string.Equals(p.Name, a.Mapping.Name, StringComparison.OrdinalIgnoreCase) && p.ParameterType.IsAssignableFrom(a.Expression.Type)); + + if (assignment is null) + assignment = members.FirstOrDefault(a => string.Equals(p.Name, a.Mapping.Name, StringComparison.OrdinalIgnoreCase) && p.ParameterType.IsAssignableFrom(a.Expression.Type)); + + if (assignment is not null) + { + args[i] = assignment.Expression; + + if (mis is not null) + mis[i] = assignment.Mapping.MemberInfo; + + used.Add(assignment); + } + else + { + var mem = Members.Where(m => string.Equals(m.Name, p.Name, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); + + if (mem is not null) + { + args[i] = Expression.Constant(Types.GetDefault(p.ParameterType), p.ParameterType); + mis[i] = mem.MemberInfo; + } + else + return null; + } + } + + members.ExceptWith(used); + + return new ConstructorBindResult(Expression.New(cons, args, mis), members); + } + + private IEnumerable RemapAssignments(IEnumerable assignments, Type entityType) + { + foreach (var assign in assignments) + { + var member = Members.FirstOrDefault(f => string.Equals(f.Name, assign.Mapping.Name, StringComparison.Ordinal)); + + if (member is not null) + yield return new EntityAssignment(member, assign.Expression); + else + yield return assign; + } + } +} diff --git a/Connected.Expressions/Mappings/MappingsCache.cs b/Connected.Expressions/Mappings/MappingsCache.cs new file mode 100644 index 0000000..75063ee --- /dev/null +++ b/Connected.Expressions/Mappings/MappingsCache.cs @@ -0,0 +1,46 @@ +using System.Collections.Concurrent; +using System.Linq.Expressions; + +namespace Connected.Expressions.Mappings; +internal static class MappingsCache +{ + static MappingsCache() + { + Items = new(); + } + + private static ConcurrentDictionary Items { get; } + + public static EntityMapping Get(Type entityType) + { + if (entityType.FullName is not string key) + throw new ArgumentNullException(nameof(entityType.FullName)); + + if (Items.TryGetValue(key, out EntityMapping? existing)) + return existing; + + Items.TryAdd(key, new EntityMapping(entityType)); + + if (!Items.TryGetValue(key, out EntityMapping? result)) + throw new NullReferenceException(nameof(EntityMapping)); + + return result; + } + + public static bool CanEvaluateLocally(Expression expression) + { + if (expression is ConstantExpression cex) + { + if (cex.Value is IQueryable query && query.Provider.GetType() == typeof(IStorage<>)) + return false; + } + + if (expression is MethodCallExpression mc && (mc.Method.DeclaringType == typeof(Enumerable) || mc.Method.DeclaringType == typeof(Queryable))) + return false; + + if (expression.NodeType == ExpressionType.Convert && expression.Type == typeof(object)) + return true; + + return expression.NodeType != ExpressionType.Parameter && expression.NodeType != ExpressionType.Lambda; + } +} diff --git a/Connected.Expressions/Mappings/MemberMapping.cs b/Connected.Expressions/Mappings/MemberMapping.cs new file mode 100644 index 0000000..8582f12 --- /dev/null +++ b/Connected.Expressions/Mappings/MemberMapping.cs @@ -0,0 +1,42 @@ +using Connected.Entities.Annotations; +using System.Reflection; + +namespace Connected.Expressions.Mappings; + +internal sealed class MemberMapping +{ + public MemberMapping(PropertyInfo property) + { + Property = property; + + Initialize(); + } + + private PropertyInfo Property { get; } + + public MemberInfo MemberInfo => Property; + public bool IsValid { get; private set; } + public bool IsPrimaryKey { get; private set; } + public bool IsReadOnly { get; private set; } + public string Name { get; private set; } + + public Type Type => Property.PropertyType; + private void Initialize() + { + var persistence = Property.GetCustomAttribute(); + + if (persistence is null || persistence.Persistence.HasFlag(ColumnPersistence.Read)) + IsValid = true; + + IsReadOnly = persistence is not null && persistence.Persistence.HasFlag(ColumnPersistence.Read); + IsPrimaryKey = Property.GetCustomAttribute() is not null; + + var memberAttribute = Property.GetCustomAttribute(); + + if (memberAttribute is not null && !string.IsNullOrWhiteSpace(memberAttribute.Member)) + Name = memberAttribute.Member; + else + Name = MemberInfo.Name.ToCamelCase(); + } +} + diff --git a/Connected.Expressions/Query/EntityQuery.cs b/Connected.Expressions/Query/EntityQuery.cs new file mode 100644 index 0000000..d45e5b7 --- /dev/null +++ b/Connected.Expressions/Query/EntityQuery.cs @@ -0,0 +1,46 @@ +using System.Collections; +using System.Linq.Expressions; + +namespace Connected.Expressions.Query; + +internal sealed class EntityQuery : IQueryable, IAsyncEnumerable, IOrderedQueryable +{ + public EntityQuery(IQueryProvider provider, Expression expression) + { + Provider = provider; + Expression = expression; + } + + public Type ElementType => typeof(TEntity); + + public Expression Expression { get; } + + public IQueryProvider Provider { get; } + + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + var result = Provider.Execute(Expression); + + if (result is IEnumerable en) + { + var enumerator = en.GetEnumerator(); + + while (enumerator.MoveNext()) + { + await Task.CompletedTask; + + yield return (TEntity)enumerator.Current; + } + } + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)Provider.Execute(Expression)).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)Provider.Execute(Expression)).GetEnumerator(); + } +} diff --git a/Connected.Expressions/Query/QueryProvider.cs b/Connected.Expressions/Query/QueryProvider.cs new file mode 100644 index 0000000..b76a082 --- /dev/null +++ b/Connected.Expressions/Query/QueryProvider.cs @@ -0,0 +1,48 @@ +using Connected.Entities.Query; +using System.Linq.Expressions; + +namespace Connected.Expressions.Query; + +public abstract class QueryProvider : IAsyncQueryProvider +{ + public IQueryable CreateQuery(Expression expression) + { + var type = expression.Type.GetElementType(); + var generic = typeof(EntityQuery<>).MakeGenericType(new Type[] { type }); + + if (generic is null) + throw new NullReferenceException(nameof(type)); + + var instance = Activator.CreateInstance(generic, new object[] { this, expression }) as IQueryable; + + if (instance is null) + throw new NullReferenceException(nameof(type)); + + return instance; + } + + public IQueryable CreateQuery(Expression expression) + { + return new EntityQuery(this, expression); + } + + public object? Execute(Expression expression) + { + return OnExecute(expression); + } + + public object? Execute(Expression expression, CancellationToken cancellationToken = default) + { + return OnExecute(expression); + } + + public TResult Execute(Expression expression) + { + return (TResult)OnExecute(expression); + } + + protected virtual object? OnExecute(Expression expression) + { + return default; + } +} diff --git a/Connected.Expressions/Reflection/ReflectionExtensions.cs b/Connected.Expressions/Reflection/ReflectionExtensions.cs new file mode 100644 index 0000000..ab2d8a6 --- /dev/null +++ b/Connected.Expressions/Reflection/ReflectionExtensions.cs @@ -0,0 +1,72 @@ +using System.Reflection; +using Connected.Entities.Annotations; + +namespace Connected.Expressions.Reflection; + +internal static class ReflectionExtensions +{ + public static bool IsInQueryable(this MethodInfo method) + { + return method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(Enumerable); + } + + public static object? GetValue(this MemberInfo member, object instance) + { + var pi = member as PropertyInfo; + + if (pi is not null) + return pi.GetValue(instance, null); + + var fi = member as FieldInfo; + + if (fi is not null) + return fi.GetValue(instance); + + throw new InvalidOperationException(); + } + + public static void SetValue(this MemberInfo member, object instance, object value) + { + var pi = member as PropertyInfo; + + if (pi is not null) + { + pi.SetValue(instance, value, null); + + return; + } + + var fi = member as FieldInfo; + + if (fi is not null) + { + fi.SetValue(instance, value); + + return; + } + + throw new InvalidOperationException(); + } + + public static string MappingId(this Type type) + { + var att = type.ResolveTableAttribute(); + + return $"{att.Schema}.{att.Name}"; + } + + public static TableAttribute ResolveTableAttribute(this Type type) + { + var tableAttribute = type.GetCustomAttribute(); + + tableAttribute ??= new TableAttribute { Name = type.Name.ToCamelCase(), Schema = SchemaAttribute.DefaultSchema }; + + if (string.IsNullOrWhiteSpace(tableAttribute.Name)) + tableAttribute.Name = type.Name.ToCamelCase(); + + if (string.IsNullOrWhiteSpace(tableAttribute.Schema)) + tableAttribute.Schema = SchemaAttribute.DefaultSchema; + + return tableAttribute; + } +} diff --git a/Connected.Expressions/Serialization/DatabaseSerializer.cs b/Connected.Expressions/Serialization/DatabaseSerializer.cs new file mode 100644 index 0000000..f4697cb --- /dev/null +++ b/Connected.Expressions/Serialization/DatabaseSerializer.cs @@ -0,0 +1,244 @@ +using System.Linq.Expressions; +using Connected.Expressions.Evaluation; +using Connected.Expressions.Expressions; +using Connected.Expressions.Expressions.Serialization; +using Connected.Expressions.Formatters; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; + +namespace Connected.Expressions.Serialization; + +internal sealed class DatabaseSerializer : ExpressionSerializer +{ + public DatabaseSerializer(TextWriter writer, QueryLanguage? language) : base(writer) + { + Aliases = new(); + Language = language; + } + + private Dictionary Aliases { get; } + private QueryLanguage? Language { get; } + + public static new void Serialize(TextWriter writer, Expression expression) + { + Serialize(writer, expression, null); + } + + public static void Serialize(TextWriter writer, Expression expression, QueryLanguage? language) + { + new DatabaseSerializer(writer, language).Visit(expression); + } + + public new static string Serialize(Expression expression) + { + return Serialize((QueryLanguage?)null, expression); + } + + public static string Serialize(QueryLanguage? language, Expression expression) + { + var writer = new StringWriter(); + + Serialize(writer, expression, language); + + return writer.ToString(); + } + + protected override Expression? Visit(Expression? expression) + { + if (expression is null) + return default; + + switch ((DatabaseExpressionType)expression.NodeType) + { + case DatabaseExpressionType.Projection: + return VisitProjection((ProjectionExpression)expression); + case DatabaseExpressionType.ClientJoin: + return VisitClientJoin((ClientJoinExpression)expression); + case DatabaseExpressionType.Select: + return VisitSelect((SelectExpression)expression); + case DatabaseExpressionType.OuterJoined: + return VisitOuterJoined((OuterJoinedExpression)expression); + case DatabaseExpressionType.Column: + return VisitColumn((ColumnExpression)expression); + case DatabaseExpressionType.If: + case DatabaseExpressionType.Block: + case DatabaseExpressionType.Declaration: + return VisitCommand((CommandExpression)expression); + case DatabaseExpressionType.Batch: + return VisitBatch((BatchExpression)expression); + case DatabaseExpressionType.Function: + return VisitFunction((FunctionExpression)expression); + case DatabaseExpressionType.Entity: + return VisitEntity((EntityExpression)expression); + default: + if (expression is DatabaseExpression) + { + Write(FormatQuery(expression)); + + return expression; + } + else + return base.Visit(expression); + } + } + + private void AddAlias(Alias alias) + { + if (!Aliases.ContainsKey(alias)) + Aliases.Add(alias, Aliases.Count); + } + + private Expression VisitProjection(ProjectionExpression projection) + { + AddAlias(projection.Select.Alias); + Write("Project("); + WriteLine(Indentation.Inner); + Write("@\""); + Visit(projection.Select); + Write("\","); + WriteLine(Indentation.Same); + Visit(projection.Projector); + Write(','); + WriteLine(Indentation.Same); + Visit(projection.Aggregator); + WriteLine(Indentation.Outer); + Write(')'); + + return projection; + } + + private Expression VisitClientJoin(ClientJoinExpression join) + { + AddAlias(join.Projection.Select.Alias); + Write("ClientJoin("); + WriteLine(Indentation.Inner); + Write("OuterKey("); + VisitExpressionList(join.OuterKey); + Write("),"); + WriteLine(Indentation.Same); + Write("InnerKey("); + VisitExpressionList(join.InnerKey); + Write("),"); + WriteLine(Indentation.Same); + Visit(join.Projection); + WriteLine(Indentation.Outer); + Write(')'); + + return join; + } + + private Expression VisitOuterJoined(OuterJoinedExpression outer) + { + Write("Outer("); + WriteLine(Indentation.Inner); + Visit(outer.Test); + Write(", "); + WriteLine(Indentation.Same); + Visit(outer.Expression); + WriteLine(Indentation.Outer); + Write(')'); + + return outer; + } + + private Expression VisitSelect(SelectExpression select) + { + Write(select.QueryText); + + return select; + } + + private Expression VisitCommand(CommandExpression command) + { + Write(FormatQuery(command)); + + return command; + } + + private string FormatQuery(Expression query) + { + return SqlFormatter.Format(query); + } + + private Expression VisitBatch(BatchExpression batch) + { + Write("Batch("); + WriteLine(Indentation.Inner); + Visit(batch.Input); + Write(","); + WriteLine(Indentation.Same); + Visit(batch.Operation); + Write(","); + WriteLine(Indentation.Same); + Visit(batch.BatchSize); + Write(", "); + Visit(batch.Stream); + WriteLine(Indentation.Outer); + Write(")"); + + return batch; + } + + private Expression VisitVariable(VariableExpression vex) + { + Write(FormatQuery(vex)); + + return vex; + } + + private Expression VisitFunction(FunctionExpression function) + { + Write("FUNCTION "); + Write(function.Name); + + if (function.Arguments.Count > 0) + { + Write("("); + VisitExpressionList(function.Arguments); + Write(")"); + } + + return function; + } + + private Expression VisitEntity(EntityExpression expression) + { + Visit(expression.Expression); + + return expression; + } + + protected override Expression VisitConstant(ConstantExpression c) + { + if (c.Type == typeof(Command)) + { + var qc = (Command)c.Value; + + Write("new Connected.Expressions.Evaluation.QueryCommand {"); + WriteLine(Indentation.Inner); + Write("\"" + qc.CommandText + "\""); + Write(","); + WriteLine(Indentation.Same); + Visit(Expression.Constant(qc.Parameters)); + Write(")"); + WriteLine(Indentation.Outer); + + return c; + } + + return base.VisitConstant(c); + } + + private Expression VisitColumn(ColumnExpression column) + { + var aliasName = Aliases.TryGetValue(column.Alias, out int iAlias) ? "A" + iAlias : "A" + (column.Alias is not null ? column.Alias.GetHashCode().ToString() : "") + "?"; + + Write(aliasName); + Write("."); + Write("Column(\""); + Write(column.Name); + Write("\")"); + + return column; + } +} diff --git a/Connected.Expressions/Serialization/ExpressionSerializer.cs b/Connected.Expressions/Serialization/ExpressionSerializer.cs new file mode 100644 index 0000000..0718d23 --- /dev/null +++ b/Connected.Expressions/Serialization/ExpressionSerializer.cs @@ -0,0 +1,566 @@ +using System.Collections; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using Connected.Interop; +using ExpressionVisitor = Connected.Expressions.Visitors.ExpressionVisitor; + +namespace Connected.Expressions.Expressions.Serialization; + +internal enum Indentation +{ + Same = 0, + Inner = 1, + Outer = 2 +} + +internal class ExpressionSerializer : ExpressionVisitor +{ + private const char NewLine = '\n'; + private const char Space = ' '; + private const string Null = "null"; + static ExpressionSerializer() + { + Splitters = new char[] { '\n', '\r' }; + Special = new char[] { '\n', '\n', '\\' }; + } + protected ExpressionSerializer(TextWriter writer) + { + Writer = writer; + } + + protected int IndentationWidth { get; set; } = 2; + private int Depth { get; set; } + private TextWriter Writer { get; } + private static char[] Splitters { get; } + private static char[] Special { get; } + + public static void Serialize(TextWriter writer, Expression expression) + { + new ExpressionSerializer(writer).Visit(expression); + } + + public static string Serialize(Expression expression) + { + var writer = new StringWriter(); + + Serialize(writer, expression); + + return writer.ToString(); + } + + protected void WriteLine(Indentation style) + { + Writer.WriteLine(); + + Indent(style); + + for (var i = 0; i < Depth * IndentationWidth; i++) + Writer.Write(Space); + } + + protected void Write(char? text) + { + if (!text.HasValue) + return; + + Writer.Write(text.ToString()); + } + protected void Write(string? text) + { + if (string.IsNullOrEmpty(text)) + return; + + if (text.Contains(NewLine)) + { + var lines = text.Split(Splitters, StringSplitOptions.RemoveEmptyEntries); + var length = lines.Length; + + for (var i = 0; i < length; i++) + { + Write(lines[i]); + + if (i < length - 1) + WriteLine(Indentation.Same); + } + } + else + Writer.Write(text); + } + + protected void Indent(Indentation style) + { + if (style == Indentation.Inner) + Depth++; + else if (style == Indentation.Outer) + { + Depth--; + + System.Diagnostics.Debug.Assert(Depth >= 0); + } + } + + protected virtual string? ResolveOperator(ExpressionType type) + { + return type switch + { + ExpressionType.Not => "!", + ExpressionType.Add or ExpressionType.AddChecked => "+", + ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked => "-", + ExpressionType.Multiply or ExpressionType.MultiplyChecked => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.And => "&", + ExpressionType.AndAlso => "&&", + ExpressionType.Or => "|", + ExpressionType.OrElse => "||", + ExpressionType.LessThan => "<", + ExpressionType.LessThanOrEqual => "<=", + ExpressionType.GreaterThan => ">", + ExpressionType.GreaterThanOrEqual => ">=", + ExpressionType.Equal => "==", + ExpressionType.NotEqual => "!=", + ExpressionType.Coalesce => "??", + ExpressionType.RightShift => ">>", + ExpressionType.LeftShift => "<<", + ExpressionType.ExclusiveOr => "^", + _ => null, + }; + } + + protected override Expression VisitBinary(BinaryExpression expression) + { + switch (expression.NodeType) + { + case ExpressionType.ArrayIndex: + Visit(expression.Left); + Write("["); + Visit(expression.Right); + Write("]"); + break; + case ExpressionType.Power: + Write("POW("); + Visit(expression.Left); + Write(", "); + Visit(expression.Right); + Write(")"); + break; + default: + Visit(expression.Left); + Write(Space); + Write(ResolveOperator(expression.NodeType)); + Write(Space); + Visit(expression.Right); + break; + } + + return expression; + } + + protected override Expression VisitUnary(UnaryExpression expression) + { + switch (expression.NodeType) + { + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + Write("(("); + Write(GetTypeName(expression.Type)); + Write(")"); + Visit(expression.Operand); + Write(")"); + break; + case ExpressionType.ArrayLength: + Visit(expression.Operand); + Write(".Length"); + break; + case ExpressionType.Quote: + Visit(expression.Operand); + break; + case ExpressionType.TypeAs: + Visit(expression.Operand); + Write(" as "); + Write(GetTypeName(expression.Type)); + break; + case ExpressionType.UnaryPlus: + Visit(expression.Operand); + break; + default: + Write(ResolveOperator(expression.NodeType)); + Visit(expression.Operand); + break; + } + + return expression; + } + + protected virtual string GetTypeName(Type type) + { + var name = type.Name.Replace('+', '.'); + var iGeneneric = name.IndexOf('`'); + + if (iGeneneric > 0) + name = name[..iGeneneric]; + + var info = type.GetTypeInfo(); + + if (info.IsGenericType || info.IsGenericTypeDefinition) + { + var sb = new StringBuilder(); + + sb.Append(name); + sb.Append('<'); + + var args = info.GenericTypeArguments; + + for (int i = 0; i < args.Length; i++) + { + if (i > 0) + sb.Append(','); + + if (info.IsGenericType) + sb.Append(GetTypeName(args[i])); + } + + sb.Append('>'); + + name = sb.ToString(); + } + + return name; + } + + protected override Expression VisitConditional(ConditionalExpression expression) + { + Visit(expression.Test); + WriteLine(Indentation.Inner); + Write("? "); + Visit(expression.IfTrue); + WriteLine(Indentation.Same); + Write(": "); + Visit(expression.IfFalse); + Indent(Indentation.Outer); + + return expression; + } + + protected override IEnumerable VisitBindingList(ReadOnlyCollection bindings) + { + var length = bindings.Count; + + for (var i = 0; i < length; i++) + { + VisitBinding(bindings[i]); + + if (i < length - 1) + { + Write(','); + WriteLine(Indentation.Same); + } + } + + return bindings; + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (expression.Value is null) + Write(Null); + else if (expression.Type == typeof(string)) + { + if (expression.Value.ToString() is string value) + { + if (value.IndexOfAny(Special) >= 0) + Write('@'); + + Write('"'); + Write(expression.Value.ToString()); + Write('"'); + } + } + else if (expression.Type == typeof(DateTime)) + { + Write("new DateTime(\""); + Write(expression.Value.ToString()); + Write("\")"); + } + else if (expression.Type.IsArray) + { + if (expression.Type.GetElementType() is Type elementType) + VisitNewArray(Expression.NewArrayInit(elementType, ((IEnumerable)expression.Value).OfType().Select(v => (Expression)Expression.Constant(v, elementType)))); + } + else + Write(expression.Value.ToString()); + + return expression; + } + + protected override ElementInit VisitElementInitializer(ElementInit initializer) + { + if (initializer.Arguments.Count > 1) + { + Write('{'); + + var length = initializer.Arguments.Count; + + for (var i = 0; i < length; i++) + { + Visit(initializer.Arguments[i]); + + if (i < length - 1) + Write(", "); + } + + Write('}'); + } + else + Visit(initializer.Arguments[0]); + + return initializer; + } + + protected override IEnumerable VisitElementInitializerList(ReadOnlyCollection original) + { + var length = original.Count; + + for (var i = 0; i < length; i++) + { + VisitElementInitializer(original[i]); + + if (i < length - 1) + { + Write(','); + WriteLine(Indentation.Same); + } + } + + return original; + } + + protected override ReadOnlyCollection VisitExpressionList(ReadOnlyCollection original) + { + var length = original.Count; + + for (var i = 0; i < length; i++) + { + Visit(original[i]); + + if (i < length - 1) + { + Write(','); + WriteLine(Indentation.Same); + } + } + + return original; + } + + protected override Expression VisitInvocation(InvocationExpression expression) + { + Write("Invoke("); + WriteLine(Indentation.Inner); + VisitExpressionList(expression.Arguments); + Write(", "); + WriteLine(Indentation.Same); + Visit(expression.Expression); + WriteLine(Indentation.Same); + Write(')'); + Indent(Indentation.Outer); + + return expression; + } + + protected override Expression VisitLambda(LambdaExpression lambda) + { + if (lambda.Parameters.Count != 1) + { + Write('('); + + var length = lambda.Parameters.Count; + + for (var i = 0; i < length; i++) + { + Write(lambda.Parameters[i].Name); + + if (i < length - 1) + Write(", "); + } + + Write(')'); + } + else + Write(lambda.Parameters[0].Name); + + Write(" => "); + Visit(lambda.Body); + + return lambda; + } + + protected override Expression VisitListInit(ListInitExpression expression) + { + Visit(expression.NewExpression); + Write(" {"); + WriteLine(Indentation.Inner); + VisitElementInitializerList(expression.Initializers); + WriteLine(Indentation.Outer); + Write('}'); + + return expression; + } + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + Visit(expression.Expression); + Write('.'); + Write(expression.Member.Name); + + return expression; + } + + protected override MemberAssignment VisitMemberAssignment(MemberAssignment assignment) + { + Write(assignment.Member.Name); + Write(" = "); + Visit(assignment.Expression); + + return assignment; + } + + protected override Expression VisitMemberInit(MemberInitExpression expression) + { + Visit(expression.NewExpression); + Write(" {"); + WriteLine(Indentation.Inner); + VisitBindingList(expression.Bindings); + WriteLine(Indentation.Outer); + Write('}'); + + return expression; + } + + protected override MemberListBinding VisitMemberListBinding(MemberListBinding binding) + { + Write(binding.Member.Name); + Write(" = {"); + WriteLine(Indentation.Inner); + VisitElementInitializerList(binding.Initializers); + WriteLine(Indentation.Outer); + Write('}'); + + return binding; + } + + protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) + { + Write(binding.Member.Name); + Write(" = {"); + WriteLine(Indentation.Inner); + VisitBindingList(binding.Bindings); + WriteLine(Indentation.Outer); + Write('}'); + + return binding; + } + + protected override Expression VisitMethodCall(MethodCallExpression expression) + { + if (expression.Object is not null) + Visit(expression.Object); + else + { + if (expression.Method.DeclaringType is null) + throw new NullReferenceException(nameof(expression.Method.DeclaringType)); + + Write(GetTypeName(expression.Method.DeclaringType)); + } + + Write('.'); + Write(expression.Method.Name); + Write('('); + + if (expression.Arguments.Count > 1) + WriteLine(Indentation.Inner); + + VisitExpressionList(expression.Arguments); + + if (expression.Arguments.Count > 1) + WriteLine(Indentation.Outer); + + Write(')'); + + return expression; + } + + protected override NewExpression VisitNew(NewExpression expression) + { + if (expression.Constructor?.DeclaringType is null) + throw new NullReferenceException(nameof(expression.Constructor.DeclaringType)); + + Write("new "); + Write(GetTypeName(expression.Constructor.DeclaringType)); + Write('('); + + if (expression.Arguments.Count > 1) + WriteLine(Indentation.Inner); + + VisitExpressionList(expression.Arguments); + + if (expression.Arguments.Count > 1) + WriteLine(Indentation.Outer); + + Write(')'); + + return expression; + } + + protected override Expression VisitNewArray(NewArrayExpression expression) + { + if (Enumerables.GetEnumerableElementType(expression.Type) is not Type enumerableType) + throw new NullReferenceException(nameof(enumerableType)); + + Write("new "); + Write(GetTypeName(enumerableType)); + Write("[] {"); + + if (expression.Expressions.Count > 1) + WriteLine(Indentation.Inner); + + VisitExpressionList(expression.Expressions); + + if (expression.Expressions.Count > 1) + WriteLine(Indentation.Outer); + + Write('}'); + + return expression; + } + + protected override Expression VisitParameter(ParameterExpression expression) + { + Write(expression.Name); + + return expression; + } + + protected override Expression VisitTypeIs(TypeBinaryExpression expression) + { + Visit(expression.Expression); + Write(" is "); + Write(GetTypeName(expression.TypeOperand)); + + return expression; + } + + protected override Expression VisitUnknown(Expression expression) + { + Write(expression.ToString()); + + return expression; + } + + protected override void OnDisposing() + { + Writer?.Dispose(); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/AggregateChecker.cs b/Connected.Expressions/Translation/AggregateChecker.cs new file mode 100644 index 0000000..fe38e45 --- /dev/null +++ b/Connected.Expressions/Translation/AggregateChecker.cs @@ -0,0 +1,42 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation; + +public sealed class AggregateChecker : DatabaseVisitor +{ + private AggregateChecker() + { + + } + + private bool HasAggregate { get; set; } + + public static bool HasAggregates(SelectExpression expression) + { + var checker = new AggregateChecker(); + + checker.Visit(expression); + + return checker.HasAggregate; + } + + protected override Expression VisitAggregate(AggregateExpression aggregate) + { + HasAggregate = true; + + return aggregate; + } + + protected override Expression VisitSelect(SelectExpression select) + { + Visit(select.Where); + VisitOrderBy(select.OrderBy); + VisitColumnDeclarations(select.Columns); + + return select; + } + + protected override Expression VisitSubquery(SubqueryExpression subquery) => subquery; +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Aggregator.cs b/Connected.Expressions/Translation/Aggregator.cs new file mode 100644 index 0000000..be32e1d --- /dev/null +++ b/Connected.Expressions/Translation/Aggregator.cs @@ -0,0 +1,63 @@ +using Connected.Expressions.Collections; +using Connected.Interop; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Expressions.Translation; + +public static class Aggregator +{ + public static LambdaExpression? GetAggregator(Type expectedType, Type actualType) + { + var actualElementType = Enumerables.GetEnumerableElementType(actualType); + + if (!expectedType.GetTypeInfo().IsAssignableFrom(actualType.GetTypeInfo())) + { + var expectedElementType = Enumerables.GetEnumerableElementType(expectedType); + var p = Expression.Parameter(actualType, "p"); + Expression? body = null; + + if (expectedType.GetTypeInfo().IsAssignableFrom(actualElementType.GetTypeInfo())) + body = Expression.Call(typeof(Enumerable), "SingleOrDefault", new Type[] { actualElementType }, p); + else if (expectedType.GetTypeInfo().IsGenericType && (expectedType == typeof(IQueryable) || expectedType == typeof(IOrderedQueryable) || expectedType.GetGenericTypeDefinition() == typeof(IQueryable<>) || expectedType.GetGenericTypeDefinition() == typeof(IOrderedQueryable<>))) + { + body = Expression.Call(typeof(Queryable), "AsQueryable", new Type[] { expectedElementType }, CoerceElement(expectedElementType, p)); + + if (body.Type != expectedType) + body = Expression.Convert(body, expectedType); + } + else if (expectedType.IsArray && expectedType.GetArrayRank() == 1) + body = Expression.Call(typeof(Enumerable), "ToArray", new Type[] { expectedElementType }, CoerceElement(expectedElementType, p)); + else if (expectedType.GetTypeInfo().IsGenericType && expectedType.GetGenericTypeDefinition().GetTypeInfo().IsAssignableFrom(typeof(IList<>).GetTypeInfo())) + { + var gt = typeof(DeferredList<>).MakeGenericType(expectedType.GetTypeInfo().GenericTypeArguments); + var cn = Types.FindConstructor(gt, new Type[] { typeof(IEnumerable<>).MakeGenericType(expectedType.GetTypeInfo().GenericTypeArguments) }); + + body = Expression.New(cn, CoerceElement(expectedElementType, p)); + } + else if (expectedType.GetTypeInfo().IsAssignableFrom(typeof(List<>).MakeGenericType(actualElementType).GetTypeInfo())) + body = Expression.Call(typeof(Enumerable), "ToList", new Type[] { expectedElementType }, CoerceElement(expectedElementType, p)); + else + { + var ci = Types.FindConstructor(expectedType, new Type[] { actualType }); + + if (ci is not null) + body = Expression.New(ci, p); + } + if (body is not null) + return Expression.Lambda(body, p); + } + + return null; + } + + private static Expression CoerceElement(Type expectedElementType, Expression expression) + { + var elementType = Enumerables.GetEnumerableElementType(expression.Type); + + if (expectedElementType != elementType && (expectedElementType.GetTypeInfo().IsAssignableFrom(elementType.GetTypeInfo()) || elementType.GetTypeInfo().IsAssignableFrom(expectedElementType.GetTypeInfo()))) + return Expression.Call(typeof(Enumerable), "Cast", new Type[] { expectedElementType }, expression); + + return expression; + } +} diff --git a/Connected.Expressions/Translation/Alias.cs b/Connected.Expressions/Translation/Alias.cs new file mode 100644 index 0000000..985dcee --- /dev/null +++ b/Connected.Expressions/Translation/Alias.cs @@ -0,0 +1,15 @@ +namespace Connected.Expressions.Translation; + +public sealed class Alias +{ + private Alias() { } + public override string ToString() + { + return $"A:{GetHashCode()}"; + } + + public static Alias New() + { + return new Alias(); + } +} diff --git a/Connected.Expressions/Translation/Binder.cs b/Connected.Expressions/Translation/Binder.cs new file mode 100644 index 0000000..099982f --- /dev/null +++ b/Connected.Expressions/Translation/Binder.cs @@ -0,0 +1,1011 @@ +using Connected.Expressions.Evaluation; +using Connected.Expressions.Expressions; +using Connected.Expressions.Mappings; +using Connected.Expressions.Reflection; +using Connected.Expressions.Translation.Projections; +using Connected.Expressions.Visitors; +using Connected.Interop; +using System.Collections; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Expressions.Translation; + +public sealed class Binder : DatabaseVisitor +{ + private Binder(ExpressionCompilationContext context, Expression expression) + { + Context = context; + Expression = expression; + + ParameterMapping = new(); + GroupByMap = new(); + } + + private ExpressionCompilationContext Context { get; } + private Dictionary ParameterMapping { get; } + private Dictionary GroupByMap { get; } + private List? ThenBys { get; set; } + private Expression CurrentGroupElement { get; set; } + private Expression Expression { get; set; } + public static Expression? Bind(ExpressionCompilationContext context, Expression expression) + { + return new Binder(context, expression).Visit(expression); + } + + public static Expression Bind(Expression source, MemberInfo member) + { + switch (source.NodeType) + { + case (ExpressionType)DatabaseExpressionType.Entity: + var ex = (EntityExpression)source; + var result = Bind(ex.Expression, member); + var mex = result as MemberExpression; + + if (mex is not null && mex.Expression == ex.Expression && mex.Member == member) + return Expression.MakeMemberAccess(source, member); + + return result; + case ExpressionType.Convert: + var ux = (UnaryExpression)source; + + return Bind(ux.Operand, member); + case ExpressionType.MemberInit: + var min = (MemberInitExpression)source; + + for (var i = 0; i < min.Bindings.Count; i++) + { + var assign = min.Bindings[i] as MemberAssignment; + + if (assign is not null && MembersMatch(assign.Member, member)) + return assign.Expression; + } + + break; + case ExpressionType.New: + var nex = (NewExpression)source; + + if (nex.Members is not null) + { + for (var i = 0; i < nex.Members.Count; i++) + { + if (MembersMatch(nex.Members[i], member)) + return nex.Arguments[i]; + } + } + else if (nex.Type.GetTypeInfo().IsGenericType && nex.Type.GetGenericTypeDefinition() == typeof(Grouping<,>)) + { + if (string.Equals(member.Name, "Key", StringComparison.Ordinal)) + return nex.Arguments[0]; + } + + break; + case (ExpressionType)DatabaseExpressionType.Projection: + var proj = (ProjectionExpression)source; + var newProjector = Bind(proj.Projector, member); + var mt = Members.GetMemberType(member); + + return new ProjectionExpression(proj.Select, newProjector, Aggregator.GetAggregator(mt, typeof(IEnumerable<>).MakeGenericType(mt))); + + case (ExpressionType)DatabaseExpressionType.OuterJoined: + var oj = (OuterJoinedExpression)source; + var em = Bind(oj.Expression, member); + + if (em is ColumnExpression) + return em; + + return new OuterJoinedExpression(oj.Test, em); + case ExpressionType.Conditional: + var cex = (ConditionalExpression)source; + + return Expression.Condition(cex.Test, Bind(cex.IfTrue, member), Bind(cex.IfFalse, member)); + case ExpressionType.Constant: + var con = (ConstantExpression)source; + var memberType = Members.GetMemberType(member); + + if (con.Value is null) + return Expression.Constant(GetDefault(memberType), memberType); + else + return Expression.Constant(GetValue(con.Value, member), memberType); + } + + return Expression.MakeMemberAccess(source, member); + } + + protected override Expression? VisitMethodCall(MethodCallExpression expression) + { + if (expression.Method.IsInQueryable()) + { + switch (expression.Method.Name) + { + case "Where": + return BindWhere(expression.Arguments[0], GetLambda(expression.Arguments[1])); + case "Select": + return BindSelect(expression.Arguments[0], GetLambda(expression.Arguments[1])); + case "SelectMany": + + if (expression.Arguments.Count == 2) + return BindSelectMany(expression.Arguments[0], GetLambda(expression.Arguments[1]), null); + else if (expression.Arguments.Count == 3) + return BindSelectMany(expression.Arguments[0], GetLambda(expression.Arguments[1]), GetLambda(expression.Arguments[2])); + + break; + case "Join": + + return BindJoin(expression.Arguments[0], expression.Arguments[1], GetLambda(expression.Arguments[2]), + GetLambda(expression.Arguments[3]), GetLambda(expression.Arguments[4])); + + case "GroupJoin": + + if (expression.Arguments.Count == 5) + { + return BindGroupJoin(expression.Method, expression.Arguments[0], expression.Arguments[1], GetLambda(expression.Arguments[2]), + GetLambda(expression.Arguments[3]), GetLambda(expression.Arguments[4])); + } + + break; + case "OrderBy": + return BindOrderBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), OrderType.Ascending); + case "OrderByDescending": + return BindOrderBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), OrderType.Descending); + case "ThenBy": + return BindThenBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), OrderType.Ascending); + case "ThenByDescending": + return BindThenBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), OrderType.Descending); + case "GroupBy": + + if (expression.Arguments.Count == 2) + return BindGroupBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), null, null); + else if (expression.Arguments.Count == 3) + { + var lambda1 = GetLambda(expression.Arguments[1]); + var lambda2 = GetLambda(expression.Arguments[2]); + + if (lambda2.Parameters.Count == 1) + return BindGroupBy(expression.Arguments[0], lambda1, lambda2, null); + else if (lambda2.Parameters.Count == 2) + return BindGroupBy(expression.Arguments[0], lambda1, null, lambda2); + } + else if (expression.Arguments.Count == 4) + return BindGroupBy(expression.Arguments[0], GetLambda(expression.Arguments[1]), GetLambda(expression.Arguments[2]), GetLambda(expression.Arguments[3])); + + break; + case "Distinct": + + if (expression.Arguments.Count == 1) + return BindDistinct(expression.Arguments[0]); + + break; + case "Skip": + + if (expression.Arguments.Count == 2) + return BindSkip(expression.Arguments[0], expression.Arguments[1]); + + break; + case "Take": + + if (expression.Arguments.Count == 2) + return BindTake(expression.Arguments[0], expression.Arguments[1]); + + break; + case "First": + case "FirstOrDefault": + case "Single": + case "SingleOrDefault": + case "Last": + case "LastOrDefault": + + if (expression.Arguments.Count == 1) + return BindFirst(expression.Arguments[0], null, expression.Method.Name, expression == Expression); + else if (expression.Arguments.Count == 2) + return BindFirst(expression.Arguments[0], GetLambda(expression.Arguments[1]), expression.Method.Name, expression == Expression); + + break; + case "Any": + + if (expression.Arguments.Count == 1) + return BindAnyAll(expression.Arguments[0], expression.Method, null, expression == Expression); + else if (expression.Arguments.Count == 2) + return BindAnyAll(expression.Arguments[0], expression.Method, GetLambda(expression.Arguments[1]), expression == Expression); + + break; + case "All": + if (expression.Arguments.Count == 2) + return BindAnyAll(expression.Arguments[0], expression.Method, GetLambda(expression.Arguments[1]), expression == Expression); + break; + case "Contains": + if (expression.Arguments.Count == 2) + return BindContains(expression.Arguments[0], expression.Arguments[1], expression == Expression); + break; + case "Cast": + if (expression.Arguments.Count == 1) + return BindCast(expression.Arguments[0], expression.Method.GetGenericArguments()[0]); + break; + case "Reverse": + return BindReverse(expression.Arguments[0]); + case "Intersect": + case "Except": + if (expression.Arguments.Count == 2) + return BindIntersect(expression.Arguments[0], expression.Arguments[1], expression.Method.Name == "Except"); + break; + } + } + + if (Context.Language.IsAggregate(expression.Method)) + { + var lambda = expression.Arguments.Count > 1 ? GetLambda(expression.Arguments[1]) : null; + return BindAggregate(expression.Arguments[0], expression.Method.Name, expression.Method.ReturnType, lambda, expression == Expression); + } + + return base.VisitMethodCall(expression); + } + + private Expression BindAggregate(Expression expression, string aggregateName, Type returnType, LambdaExpression? argument, bool isRoot) + { + var hasPredicateArg = Context.Language.IsAggregateArgumentPredicate(aggregateName); + var isDistinct = false; + var argumentWasPredicate = false; + var useAlternateArg = false; + var methodCall = expression as MethodCallExpression; + + if (methodCall is not null && !hasPredicateArg && argument is null) + { + if (string.Equals(methodCall.Method.Name, "Distinct", StringComparison.Ordinal) && methodCall.Arguments.Count == 1 && + methodCall.Method.IsInQueryable() && Context.Language.AllowDistinctInAggregates) + { + expression = methodCall.Arguments[0]; + isDistinct = true; + } + } + + if (argument is not null && hasPredicateArg) + { + var enType = expression.Type.GetEnumerableElementType(); + expression = Expression.Call(typeof(Queryable), "Where", enType is null ? null : new[] { enType }, expression, argument); + argument = null; + argumentWasPredicate = true; + } + + var projection = VisitSequence(expression); + Expression? argExpr = null; + + if (argument is not null) + { + ParameterMapping[argument.Parameters[0]] = projection.Projector; + argExpr = Visit(argument.Body); + } + else if (!hasPredicateArg || useAlternateArg) + argExpr = projection.Projector; + + var alias = Alias.New(); + + ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + var aggExpr = new AggregateExpression(returnType, aggregateName, argExpr, isDistinct); + var colType = Context.Language.TypeSystem.ResolveColumnType(returnType); + var select = new SelectExpression(alias, new ColumnDeclaration[] { new ColumnDeclaration(string.Empty, aggExpr, colType) }, projection.Select, null); + + if (isRoot) + { + var p = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(aggExpr.Type), "p"); + var gator = Expression.Lambda(Expression.Call(typeof(Enumerable), "Single", new Type[] { returnType }, p), p); + + return new ProjectionExpression(select, new ColumnExpression(returnType, Context.Language.TypeSystem.ResolveColumnType(returnType), alias, ""), gator); + } + + var subquery = new ScalarExpression(returnType, select); + + if (!argumentWasPredicate && GroupByMap.TryGetValue(projection, out GroupByDescriptor? info)) + { + if (argument is not null) + { + ParameterMapping[argument.Parameters[0]] = info.Element; + argExpr = Visit(argument.Body); + } + else if (!hasPredicateArg || useAlternateArg) + argExpr = info.Element; + + if (aggExpr is not null) + aggExpr = new AggregateExpression(returnType, aggregateName, argExpr, isDistinct); + + if (projection == CurrentGroupElement) + return aggExpr; + + return new AggregateSubqueryExpression(info.Alias, aggExpr, subquery); + } + + return subquery; + } + + private static LambdaExpression GetLambda(Expression expression) + { + while (expression.NodeType == ExpressionType.Quote) + expression = ((UnaryExpression)expression).Operand; + + if (expression.NodeType == ExpressionType.Constant) + { + if (expression is not ConstantExpression constantExpression) + throw new InvalidCastException(nameof(ConstantExpression)); + + if (constantExpression.Value is not LambdaExpression lambdaExpression) + throw new InvalidCastException(nameof(LambdaExpression)); + + return lambdaExpression; + } + + if (expression is not LambdaExpression lambda) + throw new InvalidCastException(nameof(LambdaExpression)); + + return lambda; + } + + private Expression BindWhere(Expression source, LambdaExpression predicate) + { + var projection = VisitSequence(source); + + ParameterMapping[predicate.Parameters[0]] = projection.Projector; + + var where = Visit(predicate.Body); + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, where), pc.Projector); + } + + private ProjectionExpression VisitSequence(Expression source) => ConvertToSequence(Visit(source)); + + private static ProjectionExpression ConvertToSequence(Expression expression) + { + switch (expression.NodeType) + { + case (ExpressionType)DatabaseExpressionType.Projection: + return (ProjectionExpression)expression; + case ExpressionType.New: + var nex = (NewExpression)expression; + + if (expression.Type.GetTypeInfo().IsGenericType && expression.Type.GetGenericTypeDefinition() == typeof(Grouping<,>)) + return (ProjectionExpression)nex.Arguments[1]; + + goto default; + case ExpressionType.MemberAccess: + + if (expression.NodeType != ExpressionType.MemberAccess) + return ConvertToSequence(expression); + + goto default; + default: + + if (GetNewExpression(expression) is Expression newExpression) + { + expression = newExpression; + + goto case ExpressionType.New; + } + + throw new NotSupportedException($"The expression of type '{expression.Type}' is not a sequence"); + } + } + + private static NewExpression? GetNewExpression(Expression expression) + { + while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked) + expression = ((UnaryExpression)expression).Operand; + + return expression as NewExpression; + } + + private Expression BindSelect(Expression source, LambdaExpression selector) + { + var projection = VisitSequence(source); + + ParameterMapping[selector.Parameters[0]] = projection.Projector; + + var expression = Visit(selector.Body); + var alias = Alias.New(); + var pc = ProjectColumns(expression, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null), pc.Projector); + } + + private ProjectedColumns ProjectColumns(Expression expression, Alias alias, params Alias[] existingAliases) + { + return ColumnProjector.ProjectColumns(Context.Language, expression, null, alias, existingAliases); + } + + private Expression BindSelectMany(Expression source, LambdaExpression collectionSelector, LambdaExpression resultSelector) + { + var projection = VisitSequence(source); + + ParameterMapping[collectionSelector.Parameters[0]] = projection.Projector; + + var collection = collectionSelector.Body; + var defaultIfEmpty = false; + var mcs = collection as MethodCallExpression; + + if (mcs is not null && string.Equals(mcs.Method.Name, "DefaultIfEmpty", StringComparison.Ordinal) && mcs.Arguments.Count == 1 && mcs.Method.IsInQueryable()) + { + collection = mcs.Arguments[0]; + defaultIfEmpty = true; + } + + var collectionProjection = VisitSequence(collection); + var isTable = collectionProjection.Select.From is TableExpression; + var joinType = isTable ? JoinType.CrossJoin : defaultIfEmpty ? JoinType.OuterApply : JoinType.CrossApply; + + if (joinType == JoinType.OuterApply) + collectionProjection = Context.Language.AddOuterJoinTest(collectionProjection); + + var join = new JoinExpression(joinType, projection.Select, collectionProjection.Select, null); + var alias = Alias.New(); + ProjectedColumns pc; + + if (resultSelector is null) + pc = ProjectColumns(collectionProjection.Projector, alias, projection.Select.Alias, collectionProjection.Select.Alias); + else + { + ParameterMapping[resultSelector.Parameters[0]] = projection.Projector; + ParameterMapping[resultSelector.Parameters[1]] = collectionProjection.Projector; + + var result = Visit(resultSelector.Body); + + pc = ProjectColumns(result, alias, projection.Select.Alias, collectionProjection.Select.Alias); + } + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, join, null), pc.Projector); + } + + private Expression BindJoin(Expression outerSource, Expression innerSource, LambdaExpression outerKey, LambdaExpression innerKey, LambdaExpression resultSelector) + { + if (VisitSequence(outerSource) is not ProjectionExpression outerProjection) + throw new NullReferenceException(nameof(outerProjection)); + + if (VisitSequence(innerSource) is not ProjectionExpression innerProjection) + throw new NullReferenceException(nameof(innerProjection)); + + ParameterMapping[outerKey.Parameters[0]] = outerProjection.Projector; + + var outerKeyExpr = Visit(outerKey.Body); + + ParameterMapping[innerKey.Parameters[0]] = innerProjection.Projector; + + var innerKeyExpr = Visit(innerKey.Body); + + ParameterMapping[resultSelector.Parameters[0]] = outerProjection.Projector; + ParameterMapping[resultSelector.Parameters[1]] = innerProjection.Projector; + + if (Visit(resultSelector.Body) is not Expression resultExpression) + throw new NullReferenceException(nameof(resultExpression)); + + var join = new JoinExpression(JoinType.InnerJoin, outerProjection.Select, innerProjection.Select, outerKeyExpr.Equal(innerKeyExpr)); + var alias = Alias.New(); + var pc = ProjectColumns(resultExpression, alias, outerProjection.Select.Alias, innerProjection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, join, null), pc.Projector); + } + + private Expression BindGroupJoin(MethodInfo groupJoinMethod, Expression outerSource, Expression innerSource, LambdaExpression outerKey, + LambdaExpression innerKey, LambdaExpression resultSelector) + { + /* + * A database will treat this no differently than a SelectMany w/ result selector, so just use that translation instead + */ + var args = groupJoinMethod.GetGenericArguments(); + var outerProjection = VisitSequence(outerSource); + + ParameterMapping[outerKey.Parameters[0]] = outerProjection.Projector; + + var predicateLambda = Expression.Lambda(innerKey.Body.Equal(outerKey.Body), innerKey.Parameters[0]); + var callToWhere = Expression.Call(typeof(Enumerable), "Where", new Type[] { args[1] }, innerSource, predicateLambda); + var group = Visit(callToWhere); + + ParameterMapping[resultSelector.Parameters[0]] = outerProjection.Projector; + + if (group is not null) + ParameterMapping[resultSelector.Parameters[1]] = group; + + var resultExpr = Visit(resultSelector.Body); + var alias = Alias.New(); + var pc = ProjectColumns(resultExpr, alias, outerProjection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, outerProjection.Select, null), pc.Projector); + } + + private Expression BindOrderBy(Expression source, LambdaExpression orderSelector, OrderType orderType) + { + var myThenBys = ThenBys; + + ThenBys = null; + + var projection = VisitSequence(source); + + ParameterMapping[orderSelector.Parameters[0]] = projection.Projector; + + var orderings = new List { new OrderExpression(orderType, Visit(orderSelector.Body)) }; + + if (myThenBys is not null) + { + for (var i = myThenBys.Count - 1; i >= 0; i--) + { + var tb = myThenBys[i]; + var lambda = (LambdaExpression)tb.Expression; + + ParameterMapping[lambda.Parameters[0]] = projection.Projector; + + orderings.Add(new OrderExpression(tb.OrderType, Visit(lambda.Body))); + } + } + + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null, orderings.AsReadOnly(), null), pc.Projector); + } + + private Expression BindThenBy(Expression source, LambdaExpression orderSelector, OrderType orderType) + { + ThenBys ??= new List(); + + ThenBys.Add(new OrderExpression(orderType, orderSelector)); + + return Visit(source); + } + + private Expression BindGroupBy(Expression source, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector) + { + var projection = VisitSequence(source); + + ParameterMapping[keySelector.Parameters[0]] = projection.Projector; + + var keyExpr = Visit(keySelector.Body); + var elemExpr = projection.Projector; + + if (elementSelector is not null) + { + ParameterMapping[elementSelector.Parameters[0]] = projection.Projector; + + elemExpr = Visit(elementSelector.Body); + } + + var keyProjection = ProjectColumns(keyExpr, projection.Select.Alias, projection.Select.Alias); + var groupExprs = keyProjection.Columns.Select(c => c.Expression).ToArray(); + var subqueryBasis = VisitSequence(source); + + ParameterMapping[keySelector.Parameters[0]] = subqueryBasis.Projector; + + var subqueryKey = Visit(keySelector.Body); + var subqueryKeyPC = ProjectColumns(subqueryKey, subqueryBasis.Select.Alias, subqueryBasis.Select.Alias); + var subqueryGroupExprs = subqueryKeyPC.Columns.Select(c => c.Expression).ToArray(); + var subqueryCorrelation = BuildPredicateWithNullsEqual(subqueryGroupExprs, groupExprs); + var subqueryElemExpr = subqueryBasis.Projector; + + if (elementSelector is not null) + { + ParameterMapping[elementSelector.Parameters[0]] = subqueryBasis.Projector; + + subqueryElemExpr = Visit(elementSelector.Body); + } + + var elementAlias = Alias.New(); + var elementPC = ProjectColumns(subqueryElemExpr, elementAlias, subqueryBasis.Select.Alias); + var elementSubquery = new ProjectionExpression(new SelectExpression(elementAlias, elementPC.Columns, subqueryBasis.Select, subqueryCorrelation), elementPC.Projector); + var alias = Alias.New(); + var info = new GroupByDescriptor(alias, elemExpr); + + GroupByMap.Add(elementSubquery, info); + + Expression resultExpr; + + if (resultSelector is not null) + { + var saveGroupElement = CurrentGroupElement; + + CurrentGroupElement = elementSubquery; + + ParameterMapping[resultSelector.Parameters[0]] = keyProjection.Projector; + ParameterMapping[resultSelector.Parameters[1]] = elementSubquery; + + resultExpr = Visit(resultSelector.Body); + + CurrentGroupElement = saveGroupElement; + } + else + { + resultExpr = Expression.New(typeof(Grouping<,>).MakeGenericType(keyExpr.Type, subqueryElemExpr.Type).GetTypeInfo().DeclaredConstructors.First(), + new Expression[] { keyExpr, elementSubquery }); + + resultExpr = Expression.Convert(resultExpr, typeof(IGrouping<,>).MakeGenericType(keyExpr.Type, subqueryElemExpr.Type)); + } + + var pc = ProjectColumns(resultExpr, alias, projection.Select.Alias); + + var newResult = GetNewExpression(pc.Projector); + + if (newResult is not null && newResult.Type.GetTypeInfo().IsGenericType && newResult.Type.GetGenericTypeDefinition() == typeof(Grouping<,>)) + { + var projectedElementSubquery = newResult.Arguments[1]; + + GroupByMap.Add(projectedElementSubquery, info); + } + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null, null, groupExprs), pc.Projector); + } + + private Expression BindDistinct(Expression source) + { + var projection = VisitSequence(source); + var alias = Alias.New(); + var pc = ColumnProjector.ProjectColumns(Context.Language, ProjectionAffinity.Server, projection.Projector, null, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null, null, null, true, null, null, false), pc.Projector); + } + + private Expression BindTake(Expression source, Expression take) + { + var projection = VisitSequence(source); + + take = Visit(take); + + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null, null, null, false, null, take, false), pc.Projector); + } + + private Expression BindSkip(Expression source, Expression skip) + { + var projection = VisitSequence(source); + + skip = Visit(skip); + + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null, null, null, false, skip, null, false), pc.Projector); + } + + private Expression BindFirst(Expression source, LambdaExpression predicate, string kind, bool isRoot) + { + var projection = VisitSequence(source); + Expression? where = null; + + if (predicate is not null) + { + ParameterMapping[predicate.Parameters[0]] = projection.Projector; + where = Visit(predicate.Body); + } + + var isFirst = kind.StartsWith("First"); + var isLast = kind.StartsWith("Last"); + var take = (isFirst || isLast) ? Expression.Constant(1) : null; + + if (take is not null || where is not null) + { + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + projection = new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, where, null, null, false, null, take, isLast), pc.Projector); + } + + if (isRoot) + { + var elementType = projection.Projector.Type; + var p = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(elementType), "p"); + var gator = Expression.Lambda(Expression.Call(typeof(Enumerable), kind, new Type[] { elementType }, p), p); + + return new ProjectionExpression(projection.Select, projection.Projector, gator); + } + + return projection; + } + + private Expression BindAnyAll(Expression source, MethodInfo method, LambdaExpression predicate, bool isRoot) + { + var isAll = string.Equals(method.Name, "All", StringComparison.Ordinal); + var constSource = source as ConstantExpression; + + if (constSource is not null && !IsQuery(constSource)) + { + System.Diagnostics.Debug.Assert(!isRoot); + Expression where = null; + + foreach (object value in (IEnumerable)constSource.Value) + { + var expr = Expression.Invoke(predicate, Expression.Constant(value, predicate.Parameters[0].Type)); + + if (where is null) + where = expr; + else if (isAll) + where = where.And(expr); + else + where = where.Or(expr); + } + + return Visit(where); + } + else + { + if (isAll) + predicate = Expression.Lambda(Expression.Not(predicate.Body), predicate.Parameters.ToArray()); + + if (predicate is not null) + source = Expression.Call(typeof(Enumerable), "Where", method.GetGenericArguments(), source, predicate); + + var projection = VisitSequence(source); + Expression result = new ExistsExpression(projection.Select); + + if (isAll) + result = Expression.Not(result); + + if (isRoot) + { + if (Context.Language.AllowSubqueryInSelectWithoutFrom) + return GetSingletonSequence(result, "SingleOrDefault"); + else + { + var colType = Context.Language.TypeSystem.ResolveColumnType(typeof(int)); + var newSelect = projection.Select.SetColumns(new[] { new ColumnDeclaration("value", new AggregateExpression(typeof(int), "Count", null, false), colType) }); + var colx = new ColumnExpression(typeof(int), colType, newSelect.Alias, "value"); + var exp = isAll ? colx.Equal(Expression.Constant(0)) : colx.GreaterThan(Expression.Constant(0)); + + return new ProjectionExpression(newSelect, exp, Aggregator.GetAggregator(typeof(bool), typeof(IEnumerable))); + } + } + + return result; + } + } + + private Expression BindContains(Expression source, Expression match, bool isRoot) + { + var constSource = source as ConstantExpression; + + if (constSource is not null && !IsQuery(constSource)) + { + System.Diagnostics.Debug.Assert(!isRoot); + var values = new List(); + + foreach (object value in (IEnumerable)constSource.Value) + values.Add(Expression.Constant(Convert.ChangeType(value, match.Type), match.Type)); + + match = Visit(match); + + return new InExpression(match, values); + } + else if (isRoot && !Context.Language.AllowSubqueryInSelectWithoutFrom) + { + var p = Expression.Parameter(Enumerables.GetEnumerableElementType(source.Type), "x"); + var predicate = Expression.Lambda(p.Equal(match), p); + var exp = Expression.Call(typeof(Queryable), "Any", new Type[] { p.Type }, source, predicate); + + Expression = exp; + + return Visit(exp); + } + else + { + var projection = VisitSequence(source); + + match = Visit(match); + + var result = new InExpression(match, projection.Select); + + if (isRoot) + return GetSingletonSequence(result, "SingleOrDefault"); + + return result; + } + } + + private Expression BindCast(Expression source, Type targetElementType) + { + var projection = VisitSequence(source); + var elementType = GetTrueUnderlyingType(projection.Projector); + + if (!targetElementType.IsAssignableFrom(elementType)) + throw new InvalidOperationException($"Cannot cast elements from type '{elementType}' to type '{targetElementType}'"); + + return projection; + } + + private Expression BindIntersect(Expression outerSource, Expression innerSource, bool negate) + { + /* + * SELECT * FROM outer WHERE EXISTS(SELECT * FROM inner WHERE inner = outer)) + */ + var outerProjection = VisitSequence(outerSource); + var innerProjection = VisitSequence(innerSource); + + Expression exists = new ExistsExpression(new SelectExpression(Alias.New(), null, innerProjection.Select, innerProjection.Projector.Equal(outerProjection.Projector))); + + if (negate) + exists = Expression.Not(exists); + + var alias = Alias.New(); + var pc = ProjectColumns(outerProjection.Projector, alias, outerProjection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, outerProjection.Select, exists), pc.Projector, outerProjection.Aggregator); + } + + private Expression BindReverse(Expression expression) + { + var projection = VisitSequence(expression); + var alias = Alias.New(); + var pc = ProjectColumns(projection.Projector, alias, projection.Select.Alias); + + return new ProjectionExpression(new SelectExpression(alias, pc.Columns, projection.Select, null).SetReverse(true), pc.Projector); + } + + private static Expression? BuildPredicateWithNullsEqual(IEnumerable source1, IEnumerable source2) + { + var en1 = source1.GetEnumerator(); + var en2 = source2.GetEnumerator(); + Expression? result = null; + + while (en1.MoveNext() && en2.MoveNext()) + { + var compare = Expression.Or(new IsNullExpression(en1.Current).And(new IsNullExpression(en2.Current)), en1.Current.Equal(en2.Current)); + + result = (result is null) ? compare : result.And(compare); + } + + return result; + } + + private static bool IsQuery(Expression expression) + { + var elementType = Enumerables.GetEnumerableElementType(expression.Type); + + return elementType is not null && typeof(IQueryable<>).MakeGenericType(elementType).IsAssignableFrom(expression.Type); + } + + private Expression GetSingletonSequence(Expression expr, string aggregator) + { + var p = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(expr.Type), "p"); + LambdaExpression gator = null; + + if (aggregator is not null) + gator = Expression.Lambda(Expression.Call(typeof(Enumerable), aggregator, new Type[] { expr.Type }, p), p); + + var alias = Alias.New(); + var colType = Context.Language.TypeSystem.ResolveColumnType(expr.Type); + var select = new SelectExpression(alias, new[] { new ColumnDeclaration("value", expr, colType) }, null, null); + + return new ProjectionExpression(select, new ColumnExpression(expr.Type, colType, alias, "value"), gator); + } + + private static Type GetTrueUnderlyingType(Expression expression) + { + while (expression.NodeType == ExpressionType.Convert) + expression = ((UnaryExpression)expression).Operand; + + return expression.Type; + } + + private static bool MembersMatch(MemberInfo a, MemberInfo b) + { + if (a.Name == b.Name) + return true; + + if (a is MethodInfo && b is PropertyInfo info) + return a.Name == info.GetMethod.Name; + else if (a is PropertyInfo info1 && b is MethodInfo) + return info1.GetMethod.Name == b.Name; + + return false; + } + + private static object? GetValue(object instance, MemberInfo member) + { + var fi = member as FieldInfo; + + if (fi is not null) + return fi.GetValue(instance); + + var pi = member as PropertyInfo; + + if (pi is not null) + return pi.GetValue(instance, null); + + return null; + } + + private static object? GetDefault(Type type) + { + if (!type.GetTypeInfo().IsValueType || Nullables.IsNullableType(type)) + return null; + else + return Activator.CreateInstance(type); + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (IsQuery(expression)) + { + var q = (IQueryable)expression.Value; + + if (q.Expression.NodeType == ExpressionType.Constant) + { + var mapping = MappingsCache.Get(q.ElementType); + + return VisitSequence(mapping.CreateExpression(Context)); + } + else + { + var pev = PartialEvaluator.Eval(Context, q.Expression); + + return Visit(pev); + } + } + + return expression; + } + + protected override Expression VisitParameter(ParameterExpression expression) + { + if (ParameterMapping.TryGetValue(expression, out Expression? e)) + return e; + + return expression; + } + + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (expression.Expression is LambdaExpression lambda) + { + for (var i = 0; i < lambda.Parameters.Count; i++) + ParameterMapping[lambda.Parameters[i]] = expression.Arguments[i]; + + return Visit(lambda.Body); + } + + return base.VisitInvocation(expression); + } + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + if (expression.Expression is not null + && expression.Expression.NodeType == ExpressionType.Parameter + && !ParameterMapping.ContainsKey((ParameterExpression)expression.Expression) + && IsQuery(expression)) + { + //var mapping = MappingsCache.Get(); + + + // return this.VisitSequence(MappingsCache.Mapper.GetQueryExpression(Mapper.Mapping.GetMapping(expression.Member))); + } + + var source = Visit(expression.Expression); + + if (Context.Language.IsAggregate(expression.Member) && IsRemoteQuery(source)) + return BindAggregate(expression.Expression, expression.Member.Name, Members.GetMemberType(expression.Member), null, expression == Expression); + + var result = Bind(source, expression.Member); + var mex = result as MemberExpression; + + if (mex is not null && mex.Member == expression.Member && mex.Expression == expression.Expression) + return expression; + + return result; + } + + private bool IsRemoteQuery(Expression expression) + { + if (expression.NodeType.IsDatabaseExpression()) + return true; + + switch (expression.NodeType) + { + case ExpressionType.MemberAccess: + return IsRemoteQuery(((MemberExpression)expression).Expression); + case ExpressionType.Call: + var mc = (MethodCallExpression)expression; + + if (mc.Object is not null) + return IsRemoteQuery(mc.Object); + else if (mc.Arguments.Count > 0) + return IsRemoteQuery(mc.Arguments[0]); + break; + } + + return false; + } +} diff --git a/Connected.Expressions/Translation/ColumnDeclaration.cs b/Connected.Expressions/Translation/ColumnDeclaration.cs new file mode 100644 index 0000000..a4a8a11 --- /dev/null +++ b/Connected.Expressions/Translation/ColumnDeclaration.cs @@ -0,0 +1,27 @@ +using Connected.Expressions.Languages; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation; + +public sealed class ColumnDeclaration +{ + public ColumnDeclaration(string name, Expression expression, DataType dataType) + { + if (name is null) + throw new ArgumentNullException(nameof(name)); + + if (expression is null) + throw new ArgumentNullException(nameof(expression)); + + if (dataType is null) + throw new ArgumentNullException(nameof(dataType)); + + Name = name; + Expression = expression; + DataType = dataType; + } + + public string Name { get; } + public Expression Expression { get; } + public DataType DataType { get; } +} diff --git a/Connected.Expressions/Translation/GroupByDescriptor.cs b/Connected.Expressions/Translation/GroupByDescriptor.cs new file mode 100644 index 0000000..342db76 --- /dev/null +++ b/Connected.Expressions/Translation/GroupByDescriptor.cs @@ -0,0 +1,15 @@ +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation; + +internal sealed class GroupByDescriptor +{ + public GroupByDescriptor(Alias alias, Expression element) + { + Alias = alias; + Element = element; + } + + public Alias Alias { get; } + public Expression Element { get; } +} diff --git a/Connected.Expressions/Translation/Grouping.cs b/Connected.Expressions/Translation/Grouping.cs new file mode 100644 index 0000000..df0ce4f --- /dev/null +++ b/Connected.Expressions/Translation/Grouping.cs @@ -0,0 +1,28 @@ +using System.Collections; + +namespace Connected.Expressions.Translation; + +internal sealed class Grouping : IGrouping +{ + public Grouping(TKey key, IEnumerable group) + { + Key = key; + Group = group; + } + + public TKey Key { get; } + private IEnumerable Group { get; set; } + + public IEnumerator GetEnumerator() + { + if (!(Group is List)) + Group = Group.ToList(); + + return Group.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return Group.GetEnumerator(); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/HashedExpression.cs b/Connected.Expressions/Translation/HashedExpression.cs new file mode 100644 index 0000000..334a8ef --- /dev/null +++ b/Connected.Expressions/Translation/HashedExpression.cs @@ -0,0 +1,34 @@ +using System.Linq.Expressions; +using Connected.Expressions.Comparers; + +namespace Connected.Expressions.Translation; + +internal readonly struct HashedExpression : IEquatable +{ + private readonly Expression _expression; + private readonly int _hashCode; + + public HashedExpression(Expression expression) + { + _expression = expression; + _hashCode = Hasher.ComputeHash(expression); + } + + public override bool Equals(object? obj) + { + if (obj is not HashedExpression) + return false; + + return Equals((HashedExpression)obj); + } + + public bool Equals(HashedExpression other) + { + return _hashCode == other._hashCode && DatabaseComparer.AreEqual(_expression, other._expression); + } + + public override int GetHashCode() + { + return _hashCode; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Hasher.cs b/Connected.Expressions/Translation/Hasher.cs new file mode 100644 index 0000000..66d3458 --- /dev/null +++ b/Connected.Expressions/Translation/Hasher.cs @@ -0,0 +1,25 @@ +using System.Linq.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class Hasher : DatabaseVisitor +{ + private int _hc; + + internal static int ComputeHash(Expression expression) + { + var hasher = new Hasher(); + + hasher.Visit(expression); + + return hasher._hc; + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + _hc += (expression.Value is not null) ? expression.Value.GetHashCode() : 0; + + return expression; + } +} diff --git a/Connected.Expressions/Translation/Optimization/RedundantColumns.cs b/Connected.Expressions/Translation/Optimization/RedundantColumns.cs new file mode 100644 index 0000000..79e2f31 --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/RedundantColumns.cs @@ -0,0 +1,93 @@ +using System.Collections; +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Optimization; + +internal sealed class RedundantColumns : DatabaseVisitor +{ + private RedundantColumns() + { + Map = new(); + } + + private Dictionary Map { get; set; } + + public static Expression Remove(Expression expression) + { + if (new RedundantColumns().Visit(expression) is not Expression redundantColumnsExpression) + throw new NullReferenceException(nameof(redundantColumnsExpression)); + + return redundantColumnsExpression; + } + + protected override Expression VisitColumn(ColumnExpression column) + { + if (Map.TryGetValue(column, out ColumnExpression? mapped)) + return mapped; + + return column; + } + + protected override Expression VisitSelect(SelectExpression select) + { + select = (SelectExpression)base.VisitSelect(select); + + var cols = select.Columns.OrderBy(c => c.Name).ToList(); + var removed = new BitArray(select.Columns.Count); + var anyRemoved = false; + + for (var i = 0; i < cols.Count - 1; i++) + { + var ci = cols[i]; + var cix = ci.Expression as ColumnExpression; + var qt = cix is not null ? cix.QueryType : ci.DataType; + var cxi = new ColumnExpression(ci.Expression.Type, qt, select.Alias, ci.Name); + + for (var j = i + 1; j < cols.Count; j++) + { + if (!removed.Get(j)) + { + var cj = cols[j]; + + if (SameExpression(ci.Expression, cj.Expression)) + { + var cxj = new ColumnExpression(cj.Expression.Type, qt, select.Alias, cj.Name); + + Map.Add(cxj, cxi); + + removed.Set(j, true); + anyRemoved = true; + } + } + } + } + + if (anyRemoved) + { + var newDecls = new List(); + + for (var i = 0; i < cols.Count; i++) + { + if (!removed.Get(i)) + newDecls.Add(cols[i]); + } + + select = select.SetColumns(newDecls); + } + + return select; + } + + private static bool SameExpression(Expression a, Expression b) + { + if (a == b) + return true; + + var ca = a as ColumnExpression; + var cb = b as ColumnExpression; + + return ca is not null && cb is not null && ca.Alias == cb.Alias && ca.Name == cb.Name; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Optimization/RedundantJoins.cs b/Connected.Expressions/Translation/Optimization/RedundantJoins.cs new file mode 100644 index 0000000..d1477fd --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/RedundantJoins.cs @@ -0,0 +1,86 @@ +using System.Linq.Expressions; +using Connected.Expressions.Collections; +using Connected.Expressions.Comparers; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Optimization; + +internal sealed class RedundantJoins : DatabaseVisitor +{ + private RedundantJoins() + { + Map = new Dictionary(); + } + private Dictionary Map { get; } + + public static Expression Remove(Expression expression) + { + if (new RedundantJoins().Visit(expression) is not Expression redundantJoinExpression) + throw new NullReferenceException(nameof(redundantJoinExpression)); + + return redundantJoinExpression; + } + + protected override Expression VisitJoin(JoinExpression expression) + { + var result = base.VisitJoin(expression); + + var joinExpression = result as JoinExpression; + + if (joinExpression is not null) + { + var right = joinExpression.Right as AliasedExpression; + + if (right is not null) + { + var similarRight = FindSimilarRight(expression.Left as JoinExpression, joinExpression) as AliasedExpression; + + if (similarRight is not null) + { + Map.Add(right.Alias, similarRight.Alias); + + return joinExpression.Left; + } + } + } + + return result; + } + + private Expression? FindSimilarRight(JoinExpression? join, JoinExpression compareTo) + { + if (join is null) + return null; + + if (join.Join == compareTo.Join) + { + if (join.Right.NodeType == compareTo.Right.NodeType && DatabaseComparer.AreEqual(join.Right, compareTo.Right)) + { + if (join.Condition == compareTo.Condition) + return join.Right; + + var scope = new ScopedDictionary(null); + + scope.Add(((AliasedExpression)join.Right).Alias, ((AliasedExpression)compareTo.Right).Alias); + + if (DatabaseComparer.AreEqual(null, scope, join.Condition, compareTo.Condition)) + return join.Right; + } + } + + var result = FindSimilarRight(join.Left as JoinExpression, compareTo); + + result ??= FindSimilarRight(join.Right as JoinExpression, compareTo); + + return result; + } + + protected override Expression VisitColumn(ColumnExpression column) + { + if (Map.TryGetValue(column.Alias, out Alias? mapped)) + return new ColumnExpression(column.Type, column.QueryType, mapped, column.Name); + + return column; + } +} diff --git a/Connected.Expressions/Translation/Optimization/RedundantSubqueries.cs b/Connected.Expressions/Translation/Optimization/RedundantSubqueries.cs new file mode 100644 index 0000000..248da63 --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/RedundantSubqueries.cs @@ -0,0 +1,73 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class RedundantSubqueries : DatabaseVisitor +{ + private RedundantSubqueries() + { + } + + public static Expression Remove(Expression expression) + { + if (new RedundantSubqueries().Visit(expression) is not Expression redundandSubqueryExpression) + throw new NullReferenceException(nameof(redundandSubqueryExpression)); + + return SubqueryMerger.Merge(redundandSubqueryExpression); + } + + protected override Expression VisitSelect(SelectExpression select) + { + select = (SelectExpression)base.VisitSelect(select); + + var redundant = RedundandSubqueriesResolver.Resolve(select.From); + + if (redundant is not null) + select = Subqueries.Remove(select, redundant); + + return select; + } + + protected override Expression VisitProjection(ProjectionExpression proj) + { + proj = (ProjectionExpression)base.VisitProjection(proj); + + if (proj.Select.From is SelectExpression) + { + var redundant = RedundandSubqueriesResolver.Resolve(proj.Select); + + if (redundant is not null) + proj = Subqueries.Remove(proj, redundant); + } + + return proj; + } + + internal static bool IsNameMapProjection(SelectExpression select) + { + if (select.From is TableExpression) + return false; + + + if (select.From is not SelectExpression fromSelect || select.Columns.Count != fromSelect.Columns.Count) + return false; + + var fromColumns = fromSelect.Columns; + + for (var i = 0; i < select.Columns.Count; i++) + { + if (select.Columns[i].Expression is not ColumnExpression col || !(col.Name == fromColumns[i].Name)) + return false; + } + + return true; + } + + internal static bool IsInitialProjection(SelectExpression select) + { + return select.From is TableExpression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Optimization/Subqueries.cs b/Connected.Expressions/Translation/Optimization/Subqueries.cs new file mode 100644 index 0000000..674d029 --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/Subqueries.cs @@ -0,0 +1,74 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class Subqueries : DatabaseVisitor +{ + private Subqueries(IEnumerable selectsToRemove) + { + SelectsToRemove = new HashSet(selectsToRemove); + Map = SelectsToRemove.ToDictionary(d => d.Alias, d => d.Columns.ToDictionary(d2 => d2.Name, d2 => d2.Expression)); + } + + private HashSet SelectsToRemove { get; set; } + private Dictionary> Map { get; set; } + + public static SelectExpression Remove(SelectExpression outerSelect, params SelectExpression[] selectsToRemove) + { + return Remove(outerSelect, (IEnumerable)selectsToRemove); + } + + public static SelectExpression Remove(SelectExpression outerSelect, IEnumerable selectsToRemove) + { + if (new Subqueries(selectsToRemove).Visit(outerSelect) is not SelectExpression selectRemoveExpression) + throw new NullReferenceException(nameof(selectRemoveExpression)); + + return selectRemoveExpression; + } + + public static ProjectionExpression Remove(ProjectionExpression projection, params SelectExpression[] selectsToRemove) + { + return Remove(projection, (IEnumerable)selectsToRemove); + } + + public static ProjectionExpression Remove(ProjectionExpression projection, IEnumerable selectsToRemove) + { + if (new Subqueries(selectsToRemove).Visit(projection) is not ProjectionExpression projectionRemoveExpression) + throw new NullReferenceException(nameof(projectionRemoveExpression)); + + return projectionRemoveExpression; + } + + protected override Expression VisitSelect(SelectExpression expression) + { + if (SelectsToRemove.Contains(expression)) + { + if (Visit(expression.From) is not Expression fromExpression) + throw new NullReferenceException(nameof(fromExpression)); + + return fromExpression; + } + else + return base.VisitSelect(expression); + } + + protected override Expression VisitColumn(ColumnExpression expression) + { + if (Map.TryGetValue(expression.Alias, out Dictionary? nameMap)) + { + if (nameMap.TryGetValue(expression.Name, out Expression? expr)) + { + if (Visit(expr) is not Expression columnExpression) + throw new NullReferenceException(nameof(columnExpression)); + + return columnExpression; + } + + throw new NullReferenceException($"Reference to undefined column ({expression.Name})"); + } + + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Optimization/SubqueryMerger.cs b/Connected.Expressions/Translation/Optimization/SubqueryMerger.cs new file mode 100644 index 0000000..919f42d --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/SubqueryMerger.cs @@ -0,0 +1,135 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class SubqueryMerger : DatabaseVisitor +{ + private SubqueryMerger() + { + } + + internal static Expression Merge(Expression expression) + { + if (new SubqueryMerger().Visit(expression) is not Expression subqueryExpression) + throw new NullReferenceException(nameof(subqueryExpression)); + + return subqueryExpression; + } + + private bool IsTopLevel { get; set; } = true; + + protected override Expression VisitSelect(SelectExpression expression) + { + var wasTopLevel = IsTopLevel; + + IsTopLevel = false; + + expression = (SelectExpression)base.VisitSelect(expression); + + while (CanMergeWithFrom(expression, wasTopLevel)) + { + if (GetLeftMostSelect(expression.From) is not SelectExpression fromSelectExpression) + throw new NullReferenceException(nameof(fromSelectExpression)); + + expression = Subqueries.Remove(expression, fromSelectExpression); + + var where = expression.Where; + + if (fromSelectExpression.Where is not null) + { + if (where is not null) + where = fromSelectExpression.Where.And(where); + else + where = fromSelectExpression.Where; + } + + var orderBy = expression.OrderBy is not null && expression.OrderBy.Count > 0 ? expression.OrderBy : fromSelectExpression.OrderBy; + var groupBy = expression.GroupBy is not null && expression.GroupBy.Count > 0 ? expression.GroupBy : fromSelectExpression.GroupBy; + var skip = expression.Skip is not null ? expression.Skip : fromSelectExpression.Skip; + var take = expression.Take is not null ? expression.Take : fromSelectExpression.Take; + var isDistinct = expression.IsDistinct | fromSelectExpression.IsDistinct; + + if (where != expression.Where || orderBy != expression.OrderBy || groupBy != expression.GroupBy || isDistinct != expression.IsDistinct || skip != expression.Skip || take != expression.Take) + expression = new SelectExpression(expression.Alias, expression.Columns, expression.From, where, orderBy, groupBy, isDistinct, skip, take, expression.IsReverse); + } + + return expression; + } + + private static SelectExpression? GetLeftMostSelect(Expression source) + { + var select = source as SelectExpression; + + if (select is not null) + return select; + + if (source is JoinExpression join) + return GetLeftMostSelect(join.Left); + + return null; + } + + private static bool IsColumnProjection(SelectExpression select) + { + for (var i = 0; i < select.Columns.Count; i++) + { + var cd = select.Columns[i]; + + if (cd.Expression.NodeType != (ExpressionType)DatabaseExpressionType.Column && cd.Expression.NodeType != ExpressionType.Constant) + return false; + } + + return true; + } + + private static bool CanMergeWithFrom(SelectExpression select, bool isTopLevel) + { + var fromSelect = GetLeftMostSelect(select.From); + + if (fromSelect is null) + return false; + + if (!IsColumnProjection(fromSelect)) + return false; + + var selHasNameMapProjection = RedundantSubqueries.IsNameMapProjection(select); + var selHasOrderBy = select.OrderBy is not null && select.OrderBy.Count > 0; + var selHasGroupBy = select.GroupBy is not null && select.GroupBy.Count > 0; + var selHasAggregates = AggregateChecker.HasAggregates(select); + var selHasJoin = select.From is JoinExpression; + var frmHasOrderBy = fromSelect.OrderBy is not null && fromSelect.OrderBy.Count > 0; + var frmHasGroupBy = fromSelect.GroupBy is not null && fromSelect.GroupBy.Count > 0; + var frmHasAggregates = AggregateChecker.HasAggregates(fromSelect); + + if (selHasOrderBy && frmHasOrderBy) + return false; + + if (selHasGroupBy && frmHasGroupBy) + return false; + + if (select.IsReverse || fromSelect.IsReverse) + return false; + + if (frmHasOrderBy && (selHasGroupBy || selHasAggregates || select.IsDistinct)) + return false; + + if (frmHasGroupBy) + return false; + + if (fromSelect.Take is not null && (select.Take is not null || select.Skip is not null || select.IsDistinct || selHasAggregates || selHasGroupBy || selHasJoin)) + return false; + + if (fromSelect.Skip is not null && (select.Skip is not null || select.IsDistinct || selHasAggregates || selHasGroupBy || selHasJoin)) + return false; + + if (fromSelect.IsDistinct && (select.Take is not null || select.Skip is not null || !selHasNameMapProjection || selHasGroupBy || selHasAggregates || (selHasOrderBy && !isTopLevel) || selHasJoin)) + return false; + + if (frmHasAggregates && (select.Take is not null || select.Skip is not null || select.IsDistinct || selHasAggregates || selHasGroupBy || selHasJoin)) + return false; + + return true; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Optimization/UnusedColumns.cs b/Connected.Expressions/Translation/Optimization/UnusedColumns.cs new file mode 100644 index 0000000..f028da9 --- /dev/null +++ b/Connected.Expressions/Translation/Optimization/UnusedColumns.cs @@ -0,0 +1,202 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Optimization; + +internal sealed class UnusedColumns : DatabaseVisitor +{ + private UnusedColumns() + { + AllUsed = new(); + } + private Dictionary> AllUsed { get; set; } + + private bool RetainAllColumns { get; set; } + + public static Expression Remove(Expression expression) + { + if (new UnusedColumns().Visit(expression) is not Expression unusedColumnExpression) + throw new NullReferenceException(nameof(unusedColumnExpression)); + + return unusedColumnExpression; + } + + private void MarkAsUsed(Alias alias, string name) + { + if (!AllUsed.TryGetValue(alias, out HashSet? columns)) + { + columns = new HashSet(); + + AllUsed.Add(alias, columns); + } + + columns.Add(name); + } + + private bool IsUsed(Alias alias, string name) + { + if (AllUsed.TryGetValue(alias, out HashSet? columnsUsed)) + { + if (columnsUsed is not null) + return columnsUsed.Contains(name); + } + + return false; + } + + private void ClearUsed(Alias alias) + { + AllUsed[alias] = new HashSet(); + } + + protected override Expression VisitColumn(ColumnExpression expression) + { + MarkAsUsed(expression.Alias, expression.Name); + + return expression; + } + + protected override Expression VisitSubquery(SubqueryExpression subquery) + { + if ((subquery.NodeType == (ExpressionType)DatabaseExpressionType.Scalar || subquery.NodeType == (ExpressionType)DatabaseExpressionType.In) && subquery.Select is not null) + { + System.Diagnostics.Debug.Assert(subquery.Select.Columns.Count == 1); + MarkAsUsed(subquery.Select.Alias, subquery.Select.Columns[0].Name); + } + + return base.VisitSubquery(subquery); + } + + protected override Expression VisitSelect(SelectExpression select) + { + var columns = select.Columns; + var wasRetained = RetainAllColumns; + + RetainAllColumns = false; + + List? alternate = null; + + for (var i = 0; i < select.Columns.Count; i++) + { + var decl = select.Columns[i]; + + if (wasRetained || select.IsDistinct || IsUsed(select.Alias, decl.Name)) + { + if (Visit(decl.Expression) is not Expression declarationExpression) + throw new NullReferenceException(nameof(declarationExpression)); + + if (declarationExpression != decl.Expression) + decl = new ColumnDeclaration(decl.Name, declarationExpression, decl.DataType); + } + else + decl = null; + + if (decl != select.Columns[i] && alternate is null) + { + alternate = new List(); + + for (var j = 0; j < i; j++) + alternate.Add(select.Columns[j]); + } + + if (decl is not null && alternate is not null) + alternate.Add(decl); + } + + if (alternate is not null) + columns = alternate.AsReadOnly(); + + var take = Visit(select.Take); + var skip = Visit(select.Skip); + var groupbys = VisitExpressionList(select.GroupBy); + var orderbys = VisitOrderBy(select.OrderBy); + var where = Visit(select.Where); + + if (Visit(select.From) is not Expression fromExpression) + throw new NullReferenceException(nameof(fromExpression)); + + ClearUsed(select.Alias); + + if (columns != select.Columns || take != select.Take || skip != select.Skip || orderbys != select.OrderBy || groupbys != select.GroupBy || where != select.Where || fromExpression != select.From) + select = new SelectExpression(select.Alias, columns, fromExpression, where, orderbys, groupbys, select.IsDistinct, skip, take, select.IsReverse); + + RetainAllColumns = wasRetained; + + return select; + } + + protected override Expression VisitAggregate(AggregateExpression expression) + { + /* + * COUNT(*) forces all columns to be retained in subquery + */ + if (string.Equals(expression.AggregateName, "Count", StringComparison.OrdinalIgnoreCase) && expression.Argument is null) + RetainAllColumns = true; + + return base.VisitAggregate(expression); + } + + protected override Expression VisitProjection(ProjectionExpression expression) + { + if (Visit(expression.Projector) is not Expression projector) + throw new NullReferenceException(nameof(projector)); + + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + return UpdateProjection(expression, selectExpression, projector, expression.Aggregator); + } + + protected override Expression VisitClientJoin(ClientJoinExpression expression) + { + var innerKey = VisitExpressionList(expression.InnerKey); + var outerKey = VisitExpressionList(expression.OuterKey); + + if (Visit(expression.Projection) is not ProjectionExpression projectionExpression) + throw new NullReferenceException(nameof(projectionExpression)); + + if (projectionExpression != expression.Projection || innerKey != expression.InnerKey || outerKey != expression.OuterKey) + return new ClientJoinExpression(projectionExpression, outerKey, innerKey); + + return expression; + } + + protected override Expression VisitJoin(JoinExpression expression) + { + if (expression.Join == JoinType.SingletonLeftOuter) + { + var right = Visit(expression.Right); + var ax = right as AliasedExpression; + + if (ax is not null && !AllUsed.ContainsKey(ax.Alias)) + { + if (Visit(expression.Left) is not Expression leftOuterExpression) + throw new NullReferenceException(nameof(leftOuterExpression)); + + return leftOuterExpression; + } + + if (Visit(expression.Condition) is not Expression conditionExpression) + throw new NullReferenceException(nameof(conditionExpression)); + + if (Visit(expression.Left) is not Expression leftExpression) + throw new NullReferenceException(nameof(leftExpression)); + + if (Visit(expression.Right) is not Expression rightExpression) + throw new NullReferenceException(nameof(rightExpression)); + + return UpdateJoin(expression, expression.Join, leftExpression, rightExpression, conditionExpression); + } + else + { + if (Visit(expression.Condition) is not Expression conditionExpression) + throw new NullReferenceException(nameof(conditionExpression)); + + var right = VisitSource(expression.Right); + var left = VisitSource(expression.Left); + + return UpdateJoin(expression, expression.Join, left, right, conditionExpression); + } + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Parameterizer.cs b/Connected.Expressions/Translation/Parameterizer.cs new file mode 100644 index 0000000..e3d0543 --- /dev/null +++ b/Connected.Expressions/Translation/Parameterizer.cs @@ -0,0 +1,159 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class Parameterizer : DatabaseVisitor +{ + private Parameterizer(QueryLanguage language) + { + Language = language; + Map = new(); + ParameterMap = new(); + } + + private QueryLanguage Language { get; } + private Dictionary Map { get; } + private Dictionary ParameterMap { get; } + private int Counter { get; set; } + + public static Expression Parameterize(QueryLanguage language, Expression expression) + { + if (new Parameterizer(language).Visit(expression) is not Expression parameterizedExpression) + throw new NullReferenceException(nameof(parameterizedExpression)); + + return parameterizedExpression; + } + + protected override Expression VisitProjection(ProjectionExpression expression) + { + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + return UpdateProjection(expression, selectExpression, expression.Projector, expression.Aggregator); + } + + protected override Expression VisitUnary(UnaryExpression expression) + { + if (expression.NodeType == ExpressionType.Convert && expression.Operand.NodeType == ExpressionType.ArrayIndex) + { + var b = (BinaryExpression)expression.Operand; + + if (IsConstantOrParameter(b.Left) && IsConstantOrParameter(b.Right)) + return GetNamedValue(expression); + } + + return base.VisitUnary(expression); + } + + private static bool IsConstantOrParameter(Expression expression) + { + return expression is not null && (expression.NodeType == ExpressionType.Constant || expression.NodeType == ExpressionType.Parameter); + } + + protected override Expression VisitBinary(BinaryExpression expression) + { + if (Visit(expression.Left) is not Expression leftBinaryExpression) + throw new NullReferenceException(nameof(leftBinaryExpression)); + + if (Visit(expression.Right) is not Expression rightBinaryExpression) + throw new NullReferenceException(nameof(rightBinaryExpression)); + + if (leftBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.NamedValue && rightBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.Column) + { + var nv = (NamedValueExpression)leftBinaryExpression; + var c = (ColumnExpression)rightBinaryExpression; + + leftBinaryExpression = new NamedValueExpression(nv.Name, c.QueryType, nv.Value); + } + else if (rightBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.NamedValue && leftBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.Column) + { + var nv = (NamedValueExpression)rightBinaryExpression; + var c = (ColumnExpression)leftBinaryExpression; + + rightBinaryExpression = new NamedValueExpression(nv.Name, c.QueryType, nv.Value); + } + + return UpdateBinary(expression, leftBinaryExpression, rightBinaryExpression, expression.Conversion, expression.IsLiftedToNull, expression.Method); + } + + protected override ColumnAssignment VisitColumnAssignment(ColumnAssignment ca) + { + ca = base.VisitColumnAssignment(ca); + + var expression = ca.Expression; + var nv = expression as NamedValueExpression; + + if (nv is not null) + expression = new NamedValueExpression(nv.Name, ca.Column.QueryType, nv.Value); + + return UpdateColumnAssignment(ca, ca.Column, expression); + } + + protected override Expression VisitConstant(ConstantExpression expression) + { + if (expression.Value is not null && !IsNumeric(expression.Value.GetType())) + { + var tv = new TypeValue(expression.Type, expression.Value); + + if (!Map.TryGetValue(tv, out NamedValueExpression? nv)) + { + var name = $"p{(Counter++)}"; + + nv = new NamedValueExpression(name, Language.TypeSystem.ResolveColumnType(expression.Type), expression); + + Map.Add(tv, nv); + } + + return nv; + } + + return expression; + } + + protected override Expression VisitParameter(ParameterExpression expression) => GetNamedValue(expression); + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + expression = (MemberExpression)base.VisitMemberAccess(expression); + + var nv = expression.Expression as NamedValueExpression; + + if (nv is not null) + { + var x = Expression.MakeMemberAccess(nv.Value, expression.Member); + + return GetNamedValue(x); + } + + return expression; + } + + private Expression GetNamedValue(Expression expression) + { + var he = new HashedExpression(expression); + + if (!ParameterMap.TryGetValue(he, out NamedValueExpression? nv)) + { + var name = "$p{(iParam++)}"; + + nv = new NamedValueExpression(name, Language.TypeSystem.ResolveColumnType(expression.Type), expression); + + ParameterMap.Add(he, nv); + } + + return nv; + } + + private static bool IsNumeric(Type type) + { + return Interop.TypeSystem.GetTypeCode(type) switch + { + TypeCode.Boolean or TypeCode.Byte or TypeCode.Decimal or TypeCode.Double or TypeCode.Int16 or TypeCode.Int32 or TypeCode.Int64 + or TypeCode.SByte or TypeCode.Single or TypeCode.UInt16 or TypeCode.UInt32 or TypeCode.UInt64 => true, + _ => false, + }; + } +} diff --git a/Connected.Expressions/Translation/Projections/ColumnProjector.cs b/Connected.Expressions/Translation/Projections/ColumnProjector.cs new file mode 100644 index 0000000..1e23c83 --- /dev/null +++ b/Connected.Expressions/Translation/Projections/ColumnProjector.cs @@ -0,0 +1,142 @@ +using System.Linq.Expressions; +using Connected.Expressions.Evaluation; +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Projections; + +internal enum ProjectionAffinity +{ + Client, + Server +} + +internal sealed class ColumnProjector : DatabaseVisitor +{ + private readonly QueryLanguage _language; + private readonly Dictionary _map; + private readonly List _columns; + private readonly HashSet _columnNames; + private readonly HashSet _candidates; + private readonly HashSet _existingAliases; + private readonly Alias _newAlias; + + private ColumnProjector(QueryLanguage language, ProjectionAffinity affinity, Expression expression, IEnumerable existingColumns, Alias newAlias, IEnumerable existingAliases) + { + _language = language; + _newAlias = newAlias; + _existingAliases = new HashSet(existingAliases); + _map = new Dictionary(); + + if (existingColumns is not null) + { + _columns = new List(existingColumns); + _columnNames = new HashSet(existingColumns.Select(c => c.Name)); + } + else + { + _columns = new List(); + _columnNames = new HashSet(); + } + + _candidates = ExpressionNominator.Nominate(Language, affinity, expression); + } + + private QueryLanguage Language => _language; + private Dictionary Map => _map; + private List Columns => _columns; + private HashSet ColumnNames => _columnNames; + private HashSet Candidates => _candidates; + private HashSet ExistingAliases => _existingAliases; + private Alias NewAlias => _newAlias; + private int ColumnCounter { get; set; } + public static ProjectedColumns ProjectColumns(QueryLanguage language, ProjectionAffinity affinity, Expression expression, + IEnumerable? existingColumns, Alias newAlias, IEnumerable existingAliases) + { + var projector = new ColumnProjector(language, affinity, expression, existingColumns, newAlias, existingAliases); + var expr = projector.Visit(expression); + + return new ProjectedColumns(expr, projector.Columns.AsReadOnly()); + } + + public static ProjectedColumns ProjectColumns(QueryLanguage language, Expression expression, IEnumerable? existingColumns, + Alias newAlias, IEnumerable existingAliases) + { + return ProjectColumns(language, ProjectionAffinity.Client, expression, existingColumns, newAlias, existingAliases); + } + + public static ProjectedColumns ProjectColumns(QueryLanguage language, ProjectionAffinity affinity, Expression expression, IEnumerable existingColumns, + Alias newAlias, params Alias[] existingAliases) + { + return ProjectColumns(language, affinity, expression, existingColumns, newAlias, (IEnumerable)existingAliases); + } + + public static ProjectedColumns ProjectColumns(QueryLanguage language, Expression expression, IEnumerable existingColumns, Alias newAlias, params Alias[] existingAliases) + { + return ProjectColumns(language, expression, existingColumns, newAlias, (IEnumerable)existingAliases); + } + + protected override Expression? Visit(Expression? expression) + { + if (Candidates.Contains(expression)) + { + if (expression.NodeType == (ExpressionType)DatabaseExpressionType.Column) + { + var column = (ColumnExpression)expression; + + if (Map.TryGetValue(column, out ColumnExpression? mapped)) + return mapped; + + foreach (ColumnDeclaration existingColumn in Columns) + { + if (existingColumn.Expression is ColumnExpression cex && cex.Alias == column.Alias && cex.Name == column.Name) + return new ColumnExpression(column.Type, column.QueryType, NewAlias, existingColumn.Name); + } + + if (ExistingAliases.Contains(column.Alias)) + { + var ordinal = Columns.Count; + var columnName = GetUniqueColumnName(column.Name); + + Columns.Add(new ColumnDeclaration(columnName, column, column.QueryType)); + + mapped = new ColumnExpression(column.Type, column.QueryType, NewAlias, columnName); + + Map.Add(column, mapped); + + ColumnNames.Add(columnName); + + return mapped; + } + + return column; + } + else + { + var columnName = GetNextColumnName(); + var colType = Language.TypeSystem.ResolveColumnType(expression.Type); + + Columns.Add(new ColumnDeclaration(columnName, expression, colType)); + + return new ColumnExpression(expression.Type, colType, NewAlias, columnName); + } + } + else + return base.Visit(expression); + } + + private bool IsColumnNameInUse(string name) => ColumnNames.Contains(name); + + private string GetUniqueColumnName(string name) + { + var baseName = name; + var suffix = 1; + + while (IsColumnNameInUse(name)) + name = baseName + suffix++; + return name; + } + + private string GetNextColumnName() => GetUniqueColumnName($"c{ColumnCounter++}"); +} diff --git a/Connected.Expressions/Translation/Projections/ProjectedColumns.cs b/Connected.Expressions/Translation/Projections/ProjectedColumns.cs new file mode 100644 index 0000000..cb300f1 --- /dev/null +++ b/Connected.Expressions/Translation/Projections/ProjectedColumns.cs @@ -0,0 +1,20 @@ +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation; + +internal sealed class ProjectedColumns +{ + private readonly Expression _projector; + private readonly ReadOnlyCollection _columns; + + public ProjectedColumns(Expression projector, ReadOnlyCollection columns) + { + _projector = projector; + _columns = columns; + } + + public Expression Projector => _projector; + public ReadOnlyCollection Columns => _columns; + +} diff --git a/Connected.Expressions/Translation/RelationshipBinder.cs b/Connected.Expressions/Translation/RelationshipBinder.cs new file mode 100644 index 0000000..e04216b --- /dev/null +++ b/Connected.Expressions/Translation/RelationshipBinder.cs @@ -0,0 +1,115 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Translation.Projections; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation; + +internal sealed class RelationshipBinder : DatabaseVisitor +{ + private RelationshipBinder(ExpressionCompilationContext context) + { + Context = context; + } + + private ExpressionCompilationContext Context { get; } + private Expression? CurrentFrom { get; set; } + + public static Expression Bind(ExpressionCompilationContext context, Expression expression) + { + if (new RelationshipBinder(context).Visit(expression) is not Expression relationshipExpression) + throw new NullReferenceException(nameof(relationshipExpression)); + + return relationshipExpression; + } + + protected override Expression VisitSelect(SelectExpression select) + { + /* + * look for association references in SelectExpression clauses + */ + var saveCurrentFrom = CurrentFrom; + + CurrentFrom = VisitSource(select.From); + + try + { + var where = Visit(select.Where); + var orderBy = VisitOrderBy(select.OrderBy); + var groupBy = VisitExpressionList(select.GroupBy); + var skip = Visit(select.Skip); + var take = Visit(select.Take); + var columns = VisitColumnDeclarations(select.Columns); + + return UpdateSelect(select, CurrentFrom, where, orderBy, groupBy, skip, take, select.IsDistinct, select.IsReverse, columns); + } + finally + { + CurrentFrom = saveCurrentFrom; + } + } + + protected override Expression VisitProjection(ProjectionExpression proj) + { + var select = (SelectExpression)Visit(proj.Select); + var saveCurrentFrom = CurrentFrom; + + CurrentFrom = select; + + try + { + var projector = Visit(proj.Projector); + + if (CurrentFrom != select) + { + var alias = Alias.New(); + var existingAliases = GetAliases(CurrentFrom); + var pc = ColumnProjector.ProjectColumns(Context.Language, projector, null, alias, existingAliases); + + projector = pc.Projector; + select = new SelectExpression(alias, pc.Columns, CurrentFrom, null); + } + + return UpdateProjection(proj, select, projector, proj.Aggregator); + } + finally + { + CurrentFrom = saveCurrentFrom; + } + } + + private static List GetAliases(Expression expression) + { + var aliases = new List(); + + GetAliases(expression); + + return aliases; + + void GetAliases(Expression e) + { + switch (e) + { + case JoinExpression j: + GetAliases(j.Left); + GetAliases(j.Right); + break; + case AliasedExpression a: + aliases.Add(a.Alias); + break; + } + } + } + + protected override Expression VisitMemberAccess(MemberExpression expression) + { + var source = Visit(expression.Expression); + var result = Binder.Bind(source, expression.Member); + var mex = result as MemberExpression; + + if (mex is not null && mex.Member == expression.Member && mex.Expression == expression.Expression) + return expression; + + return result; + } +} diff --git a/Connected.Expressions/Translation/Resolvers/AggregateResolver.cs b/Connected.Expressions/Translation/Resolvers/AggregateResolver.cs new file mode 100644 index 0000000..7832adb --- /dev/null +++ b/Connected.Expressions/Translation/Resolvers/AggregateResolver.cs @@ -0,0 +1,31 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Resolvers; + +internal sealed class AggregateResolver : DatabaseVisitor +{ + private AggregateResolver() + { + Aggregates = new(); + } + + private List Aggregates { get; } + + internal static List Resolve(Expression expression) + { + var resolver = new AggregateResolver(); + + resolver.Visit(expression); + + return resolver.Aggregates; + } + + protected override Expression VisitAggregateSubquery(AggregateSubqueryExpression aggregate) + { + Aggregates.Add(aggregate); + + return base.VisitAggregateSubquery(aggregate); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Resolvers/DeclaredAliasesResolver.cs b/Connected.Expressions/Translation/Resolvers/DeclaredAliasesResolver.cs new file mode 100644 index 0000000..7b6dad9 --- /dev/null +++ b/Connected.Expressions/Translation/Resolvers/DeclaredAliasesResolver.cs @@ -0,0 +1,38 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Resolvers; + +internal sealed class DeclaredAliasesResolver : DatabaseVisitor +{ + private DeclaredAliasesResolver() + { + Aliases = new(); + } + + private HashSet Aliases { get; set; } + + public static HashSet Resolve(Expression source) + { + var resolver = new DeclaredAliasesResolver(); + + resolver.Visit(source); + + return resolver.Aliases; + } + + protected override Expression VisitSelect(SelectExpression select) + { + Aliases.Add(select.Alias); + + return select; + } + + protected override Expression VisitTable(TableExpression table) + { + Aliases.Add(table.Alias); + + return table; + } +} diff --git a/Connected.Expressions/Translation/Resolvers/JoinColumnResolver.cs b/Connected.Expressions/Translation/Resolvers/JoinColumnResolver.cs new file mode 100644 index 0000000..1d9413b --- /dev/null +++ b/Connected.Expressions/Translation/Resolvers/JoinColumnResolver.cs @@ -0,0 +1,78 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; + +namespace Connected.Expressions.Translation.Resolvers; + +internal sealed class JoinColumnResolver +{ + private JoinColumnResolver(HashSet aliases) + { + Aliases = aliases; + Columns = new HashSet(); + } + + private HashSet Aliases { get; } + private HashSet Columns { get; } + + public static HashSet Resolve(HashSet aliases, SelectExpression select) + { + var resolver = new JoinColumnResolver(aliases); + + resolver.Resolve(select.Where); + + return resolver.Columns; + } + + private void Resolve(Expression? expression) + { + if (expression is BinaryExpression b) + { + switch (b.NodeType) + { + case ExpressionType.Equal: + case ExpressionType.NotEqual: + + if (IsExternalColumn(b.Left) && GetColumn(b.Right) is not null) + { + if (GetColumn(b.Right) is ColumnExpression right) + Columns.Add(right); + } + else if (IsExternalColumn(b.Right) && GetColumn(b.Left) is not null) + { + if (GetColumn(b.Left) is ColumnExpression left) + Columns.Add(left); + } + + break; + case ExpressionType.And: + case ExpressionType.AndAlso: + + if (b.Type == typeof(bool) || b.Type == typeof(bool?)) + { + Resolve(b.Left); + Resolve(b.Right); + } + + break; + } + } + } + + private static ColumnExpression? GetColumn(Expression exp) + { + while (exp.NodeType == ExpressionType.Convert) + exp = ((UnaryExpression)exp).Operand; + + return exp as ColumnExpression; + } + + private bool IsExternalColumn(Expression exp) + { + var col = GetColumn(exp); + + if (col is not null && !Aliases.Contains(col.Alias)) + return true; + + return false; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Resolvers/RedundandSubqueriesResolver.cs b/Connected.Expressions/Translation/Resolvers/RedundandSubqueriesResolver.cs new file mode 100644 index 0000000..f17820c --- /dev/null +++ b/Connected.Expressions/Translation/Resolvers/RedundandSubqueriesResolver.cs @@ -0,0 +1,63 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Resolvers; + +internal class RedundandSubqueriesResolver : DatabaseVisitor +{ + private RedundandSubqueriesResolver() + { + } + + private List? Redundant { get; set; } + + internal static List Resolve(Expression expression) + { + var retriever = new RedundandSubqueriesResolver(); + + retriever.Visit(expression); + + return retriever.Redundant; + } + + private static bool IsRedudantSubquery(SelectExpression expression) + { + return (IsSimpleProjection(expression) || RedundantSubqueries.IsNameMapProjection(expression)) + && !expression.IsDistinct + && !expression.IsReverse + && expression.Take is null + && expression.Skip is null + && expression.Where is null + && (expression.OrderBy is null || !expression.OrderBy.Any()) + && (expression.GroupBy is null || !expression.GroupBy.Any()); + } + + internal static bool IsSimpleProjection(SelectExpression select) + { + foreach (var decl in select.Columns) + { + if (decl.Expression is not ColumnExpression col || decl.Name != col.Name) + return false; + } + + return true; + } + + protected override Expression VisitSelect(SelectExpression expression) + { + if (IsRedudantSubquery(expression)) + { + Redundant ??= new List(); + + Redundant.Add(expression); + } + + return expression; + } + + protected override Expression VisitSubquery(SubqueryExpression expression) + { + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Resolvers/ReferencedAliasesResolver.cs b/Connected.Expressions/Translation/Resolvers/ReferencedAliasesResolver.cs new file mode 100644 index 0000000..89566d1 --- /dev/null +++ b/Connected.Expressions/Translation/Resolvers/ReferencedAliasesResolver.cs @@ -0,0 +1,31 @@ +using System.Linq.Expressions; +using Connected.Expressions.Expressions; +using Connected.Expressions.Visitors; + +namespace Connected.Expressions.Translation.Resolvers; + +internal sealed class ReferencedAliasesResolver : DatabaseVisitor +{ + private ReferencedAliasesResolver() + { + Aliases = new(); + } + + private HashSet Aliases { get; } + + public static HashSet Resolve(Expression source) + { + var resolver = new ReferencedAliasesResolver(); + + resolver.Visit(source); + + return resolver.Aliases; + } + + protected override Expression VisitColumn(ColumnExpression column) + { + Aliases.Add(column.Alias); + + return column; + } +} diff --git a/Connected.Expressions/Translation/Rewriters/AggregateRewriter.cs b/Connected.Expressions/Translation/Rewriters/AggregateRewriter.cs new file mode 100644 index 0000000..aa0013c --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/AggregateRewriter.cs @@ -0,0 +1,62 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class AggregateRewriter : DatabaseVisitor +{ + private AggregateRewriter(ExpressionCompilationContext context, Expression expr) + { + Context = context; + + Map = new(); + Lookup = AggregateResolver.Resolve(expr).ToLookup(a => a.GroupByAlias); + } + + private ExpressionCompilationContext Context { get; set; } + private ILookup Lookup { get; set; } + private Dictionary Map { get; set; } + + public static Expression Rewrite(ExpressionCompilationContext context, Expression expr) + { + if (new AggregateRewriter(context, expr).Visit(expr) is not Expression aggregateExpression) + throw new NullReferenceException(nameof(aggregateExpression)); + + return aggregateExpression; + } + + protected override Expression VisitSelect(SelectExpression select) + { + select = (SelectExpression)base.VisitSelect(select); + + if (Lookup.Contains(select.Alias)) + { + var aggColumns = new List(select.Columns); + + foreach (AggregateSubqueryExpression ae in Lookup[select.Alias]) + { + var name = $"agg{aggColumns.Count}"; + var colType = Context.Language.TypeSystem.ResolveColumnType(ae.Type); + var cd = new ColumnDeclaration(name, ae.AggregateInGroupSelect, colType); + + Map.Add(ae, new ColumnExpression(ae.Type, colType, ae.GroupByAlias, name)); + + aggColumns.Add(cd); + } + + return new SelectExpression(select.Alias, aggColumns, select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); + } + + return select; + } + + protected override Expression VisitAggregateSubquery(AggregateSubqueryExpression expression) + { + if (Map.TryGetValue(expression, out Expression mapped)) + return mapped; + + return Visit(expression.AggregateAsSubquery); + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/BindResultRewriter.cs b/Connected.Expressions/Translation/Rewriters/BindResultRewriter.cs new file mode 100644 index 0000000..8762860 --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/BindResultRewriter.cs @@ -0,0 +1,21 @@ +using Connected.Expressions.Expressions; +using System.Collections.ObjectModel; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class BindResultRewriter +{ + public BindResultRewriter(IEnumerable columns, IEnumerable orderings) + { + Columns = columns as ReadOnlyCollection; + + Columns ??= new List(columns).AsReadOnly(); + + Orderings = orderings as ReadOnlyCollection; + + Orderings ??= new List(orderings).AsReadOnly(); + } + + public ReadOnlyCollection? Columns { get; private set; } + public ReadOnlyCollection? Orderings { get; private set; } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/ComparisonRewriter.cs b/Connected.Expressions/Translation/Rewriters/ComparisonRewriter.cs new file mode 100644 index 0000000..eeb37f2 --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/ComparisonRewriter.cs @@ -0,0 +1,135 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Mappings; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class ComparisonRewriter : DatabaseVisitor +{ + private ComparisonRewriter() + { + } + + public static Expression Rewrite(Expression expression) + { + return new ComparisonRewriter().Visit(expression); + } + + protected override Expression VisitBinary(BinaryExpression b) + { + switch (b.NodeType) + { + case ExpressionType.Equal: + case ExpressionType.NotEqual: + var result = Compare(b); + + if (result == b) + goto default; + + return Visit(result); + default: + return base.VisitBinary(b); + } + } + + private static Expression SkipConvert(Expression expression) + { + while (expression.NodeType == ExpressionType.Convert) + expression = ((UnaryExpression)expression).Operand; + + return expression; + } + + private Expression Compare(BinaryExpression bop) + { + var e1 = SkipConvert(bop.Left); + var e2 = SkipConvert(bop.Right); + var oj1 = e1 as OuterJoinedExpression; + var oj2 = e2 as OuterJoinedExpression; + var entity1 = oj1 is not null ? oj1.Expression as EntityExpression : e1 as EntityExpression; + var entity2 = oj2 is not null ? oj2.Expression as EntityExpression : e2 as EntityExpression; + var negate = bop.NodeType == ExpressionType.NotEqual; + + if (oj1 is not null && e2.NodeType == ExpressionType.Constant && ((ConstantExpression)e2).Value is null) + return MakeIsNull(oj1.Test, negate); + else if (oj2 is not null && e1.NodeType == ExpressionType.Constant && ((ConstantExpression)e1).Value is null) + return MakeIsNull(oj2.Test, negate); + + if (entity1 is not null) + return MakePredicate(e1, e2, MappingsCache.Get(entity1.EntityType).Members.Where(f => f.IsPrimaryKey).Select(f => f.MemberInfo), negate); + else if (entity2 is not null) + return MakePredicate(e1, e2, MappingsCache.Get(entity2.EntityType).Members.Where(f => f.IsPrimaryKey).Select(f => f.MemberInfo), negate); + + var dm1 = GetDefinedMembers(e1); + var dm2 = GetDefinedMembers(e2); + + if (dm1 is null && dm2 is null) + return bop; + + if (dm1 is not null && dm2 is not null) + { + var names1 = new HashSet(dm1.Select(m => m.Name)); + var names2 = new HashSet(dm2.Select(m => m.Name)); + + if (names1.IsSubsetOf(names2) && names2.IsSubsetOf(names1)) + return MakePredicate(e1, e2, dm1, negate); + } + else if (dm1 is not null) + return MakePredicate(e1, e2, dm1, negate); + else if (dm2 is not null) + return MakePredicate(e1, e2, dm2, negate); + + throw new InvalidOperationException("Cannot compare two constructed types with different sets of members assigned."); + } + + private static Expression MakeIsNull(Expression expression, bool negate) + { + var isnull = new IsNullExpression(expression); + + return negate ? Expression.Not(isnull) : isnull; + } + + private static Expression? MakePredicate(Expression e1, Expression e2, IEnumerable members, bool negate) + { + var pred = members.Select(m => Binder.Bind(e1, m).Equal(Binder.Bind(e2, m))).Join(ExpressionType.And); + + if (negate) + pred = Expression.Not(pred); + + return pred; + } + + private static IEnumerable GetDefinedMembers(Expression expr) + { + var mini = expr as MemberInitExpression; + + if (mini is not null) + { + var members = mini.Bindings.Select(b => FixMember(b.Member)); + + if (mini.NewExpression.Members is not null) + members = members.Concat(mini.NewExpression.Members.Select(FixMember)); + + return members; + } + else + { + var nex = expr as NewExpression; + + if (nex is not null && nex.Members is not null) + return nex.Members.Select(FixMember); + } + + return null; + } + + private static MemberInfo FixMember(MemberInfo member) + { + if (member is MethodInfo && member.Name.StartsWith("get_")) + return member.DeclaringType.GetTypeInfo().GetDeclaredProperty(member.Name[4..]); + + return member; + } +} diff --git a/Connected.Expressions/Translation/Rewriters/CrossApplyRewriter.cs b/Connected.Expressions/Translation/Rewriters/CrossApplyRewriter.cs new file mode 100644 index 0000000..5d191a4 --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/CrossApplyRewriter.cs @@ -0,0 +1,68 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation.Projections; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class CrossApplyRewriter : DatabaseVisitor +{ + private CrossApplyRewriter(QueryLanguage language) + { + Language = language; + } + + private QueryLanguage Language { get; } + + public static Expression Rewrite(QueryLanguage language, Expression expression) + { + if (new CrossApplyRewriter(language).Visit(expression) is not Expression crossApplyExpression) + throw new NullReferenceException(nameof(crossApplyExpression)); + + return crossApplyExpression; + } + + protected override Expression VisitJoin(JoinExpression expression) + { + expression = (JoinExpression)base.VisitJoin(expression); + + if (expression.Join == JoinType.CrossApply || expression.Join == JoinType.OuterApply) + { + if (expression.Right is TableExpression) + return new JoinExpression(JoinType.CrossJoin, expression.Left, expression.Right, null); + else + { + var select = expression.Right as SelectExpression; + + if (select is not null && select.Take is null && select.Skip is null && !AggregateChecker.HasAggregates(select) && (select.GroupBy is null || !select.GroupBy.Any())) + { + var selectWithoutWhere = select.SetWhere(null); + var referencedAliases = ReferencedAliasesResolver.Resolve(selectWithoutWhere); + var declaredAliases = DeclaredAliasesResolver.Resolve(expression.Left); + + referencedAliases.IntersectWith(declaredAliases); + + if (!referencedAliases.Any()) + { + var where = select.Where; + + select = selectWithoutWhere; + + var pc = ColumnProjector.ProjectColumns(Language, where, select.Columns, select.Alias, DeclaredAliasesResolver.Resolve(select.From)); + + select = select.SetColumns(pc.Columns); + where = pc.Projector; + + var jt = (where == null) ? JoinType.CrossJoin : (expression.Join == JoinType.CrossApply ? JoinType.InnerJoin : JoinType.LeftOuter); + + return new JoinExpression(jt, expression.Left, select, where); + } + } + } + } + + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/CrossJoinRewriter.cs b/Connected.Expressions/Translation/Rewriters/CrossJoinRewriter.cs new file mode 100644 index 0000000..370728a --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/CrossJoinRewriter.cs @@ -0,0 +1,78 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class CrossJoinRewriter : DatabaseVisitor +{ + public static Expression Rewrite(Expression expression) + { + if (new CrossJoinRewriter().Visit(expression) is not Expression crossJoinExpression) + throw new NullReferenceException(nameof(crossJoinExpression)); + + return crossJoinExpression; + } + + private Expression? CurrentWhere { get; set; } + + protected override Expression VisitSelect(SelectExpression select) + { + var saveWhere = CurrentWhere; + + try + { + CurrentWhere = select.Where; + + var result = (SelectExpression)base.VisitSelect(select); + + if (CurrentWhere != result.Where) + return result.SetWhere(CurrentWhere); + + return result; + } + finally + { + CurrentWhere = saveWhere; + } + } + + protected override Expression VisitJoin(JoinExpression expression) + { + expression = (JoinExpression)base.VisitJoin(expression); + + if (expression.Join == JoinType.CrossJoin && CurrentWhere is not null) + { + var declaredLeft = DeclaredAliasesResolver.Resolve(expression.Left); + var declaredRight = DeclaredAliasesResolver.Resolve(expression.Right); + var declared = new HashSet(declaredLeft.Union(declaredRight)); + var exprs = CurrentWhere.Split(ExpressionType.And, ExpressionType.AndAlso); + var good = exprs.Where(e => CanBeJoinCondition(e, declaredLeft, declaredRight, declared)).ToList(); + + if (good.Any()) + { + if (good.Join(ExpressionType.And) is not Expression conditionExpression) + throw new NullReferenceException(nameof(conditionExpression)); + + expression = UpdateJoin(expression, JoinType.InnerJoin, expression.Left, expression.Right, conditionExpression); + + var newWhere = exprs.Where(e => !good.Contains(e)).Join(ExpressionType.And); + + CurrentWhere = newWhere; + } + } + + return expression; + } + + private static bool CanBeJoinCondition(Expression expression, HashSet left, HashSet right, HashSet all) + { + var referenced = ReferencedAliasesResolver.Resolve(expression); + var leftOkay = referenced.Intersect(left).Any(); + var rightOkay = referenced.Intersect(right).Any(); + var subset = referenced.IsSubsetOf(all); + + return leftOkay && rightOkay && subset; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/OrderByRewriter.cs b/Connected.Expressions/Translation/Rewriters/OrderByRewriter.cs new file mode 100644 index 0000000..bc3b3b5 --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/OrderByRewriter.cs @@ -0,0 +1,218 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation.Resolvers; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class OrderByRewriter : DatabaseVisitor +{ + private OrderByRewriter(QueryLanguage language) + { + Language = language; + IsOuterMostSelect = true; + } + + private QueryLanguage Language { get; } + private IList? ResolvedOrderings { get; set; } + private bool IsOuterMostSelect { get; set; } + + public static Expression Rewrite(QueryLanguage language, Expression expression) + { + if (new OrderByRewriter(language).Visit(expression) is not Expression orderByExpression) + throw new NullReferenceException(nameof(orderByExpression)); + + return orderByExpression; + } + + protected override Expression VisitSelect(SelectExpression select) + { + bool saveIsOuterMostSelect = IsOuterMostSelect; + + try + { + IsOuterMostSelect = false; + select = (SelectExpression)base.VisitSelect(select); + + var hasOrderBy = select.OrderBy is not null && select.OrderBy.Count > 0; + var hasGroupBy = select.GroupBy is not null && select.GroupBy.Count > 0; + var canHaveOrderBy = saveIsOuterMostSelect || select.Take is not null || select.Skip is not null; + var canReceiveOrderings = canHaveOrderBy && !hasGroupBy && !select.IsDistinct && !AggregateChecker.HasAggregates(select); + + if (hasOrderBy) + PrependOrderings(select.OrderBy); + + if (select.IsReverse) + ReverseOrderings(); + + IEnumerable? orderings = null; + + if (canReceiveOrderings) + orderings = ResolvedOrderings; + + else if (canHaveOrderBy) + orderings = select.OrderBy; + + var canPassOnOrderings = !saveIsOuterMostSelect && !hasGroupBy && !select.IsDistinct; + var columns = select.Columns; + + if (ResolvedOrderings is not null) + { + if (canPassOnOrderings) + { + var producedAliases = DeclaredAliasesResolver.Resolve(select.From); + var project = RebindOrderings(ResolvedOrderings, select.Alias, producedAliases, select.Columns); + + ResolvedOrderings = null; + + PrependOrderings(project.Orderings); + + columns = project.Columns; + } + else + ResolvedOrderings = null; + } + if (orderings != select.OrderBy || columns != select.Columns || select.IsReverse) + select = new SelectExpression(select.Alias, columns, select.From, select.Where, orderings, select.GroupBy, select.IsDistinct, select.Skip, select.Take, false); + + return select; + } + finally + { + IsOuterMostSelect = saveIsOuterMostSelect; + } + } + + protected override Expression VisitSubquery(SubqueryExpression subquery) + { + var saveOrderings = ResolvedOrderings; + + ResolvedOrderings = null; + + var result = base.VisitSubquery(subquery); + + ResolvedOrderings = saveOrderings; + + return result; + } + + protected override Expression VisitJoin(JoinExpression join) + { + var left = VisitSource(join.Left); + var leftOrders = ResolvedOrderings; + /* + * start on the right with a clean slate + */ + ResolvedOrderings = null; + + var right = VisitSource(join.Right); + + PrependOrderings(leftOrders); + + var condition = Visit(join.Condition); + + if (left != join.Left || right != join.Right || condition != join.Condition) + return new JoinExpression(join.Join, left, right, condition); + + return join; + } + + private void PrependOrderings(IList? newOrderings) + { + if (newOrderings is not null) + { + ResolvedOrderings ??= new List(); + + for (var i = newOrderings.Count - 1; i >= 0; i--) + ResolvedOrderings.Insert(0, newOrderings[i]); + + var unique = new HashSet(); + + for (var i = 0; i < ResolvedOrderings.Count;) + { + if (ResolvedOrderings[i].Expression is ColumnExpression column) + { + var hash = $"{column.Alias}:{column.Name}"; + + if (unique.Contains(hash)) + { + ResolvedOrderings.RemoveAt(i); + + continue; + } + else + unique.Add(hash); + } + + i++; + } + } + } + + private void ReverseOrderings() + { + if (ResolvedOrderings is not null) + { + for (var i = 0; i < ResolvedOrderings.Count; i++) + { + var ord = ResolvedOrderings[i]; + + ResolvedOrderings[i] = new OrderExpression(ord.OrderType == OrderType.Ascending ? OrderType.Descending : OrderType.Ascending, ord.Expression); + } + } + } + + private BindResultRewriter RebindOrderings(IEnumerable orderings, Alias alias, HashSet existingAliases, IEnumerable existingColumns) + { + List? newColumns = null; + List newOrderings = new(); + + foreach (var ordering in orderings) + { + var expr = ordering.Expression; + var column = expr as ColumnExpression; + + if (column is null || (existingAliases is not null && existingAliases.Contains(column.Alias))) + { + var ordinal = 0; + + foreach (var existingColumn in existingColumns) + { + var declColumn = existingColumn.Expression as ColumnExpression; + + if (existingColumn.Expression == ordering.Expression || (column is not null && declColumn is not null && column.Alias == declColumn.Alias && column.Name == declColumn.Name)) + { + expr = new ColumnExpression(column.Type, column.QueryType, alias, existingColumn.Name); + + break; + } + + ordinal++; + } + if (expr == ordering.Expression) + { + if (newColumns is null) + { + newColumns = new List(existingColumns); + existingColumns = newColumns; + } + + var colName = column != null ? column.Name : $"c{ordinal}"; + + colName = newColumns.ResolveAvailableColumnName(colName); + + var colType = Language.TypeSystem.ResolveColumnType(expr.Type); + + newColumns.Add(new ColumnDeclaration(colName, ordering.Expression, colType)); + + expr = new ColumnExpression(expr.Type, colType, alias, colName); + } + + newOrderings.Add(new OrderExpression(ordering.OrderType, expr)); + } + } + + return new BindResultRewriter(existingColumns, newOrderings); + } +} diff --git a/Connected.Expressions/Translation/Rewriters/ParameterRewriter.cs b/Connected.Expressions/Translation/Rewriters/ParameterRewriter.cs new file mode 100644 index 0000000..cb3dbf2 --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/ParameterRewriter.cs @@ -0,0 +1,33 @@ +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public class ParameterRewriter : DatabaseVisitor +{ + private ParameterRewriter(ExpressionCompilationContext context) + { + Context = context; + } + + private ExpressionCompilationContext Context { get; } + + public static Expression Rewrite(ExpressionCompilationContext context, Expression expression) + { + return new ParameterRewriter(context).Visit(expression); + } + + protected override Expression VisitBinary(BinaryExpression expression) + { + return base.VisitBinary(expression); + } + protected override Expression VisitConstant(ConstantExpression expression) + { + var parameter = Context.Parameters.FirstOrDefault(f => f.Value == expression); + + if (parameter.Value is not null) + return Expression.Constant($"@{parameter.Key}"); + + return base.VisitConstant(expression); + } +} diff --git a/Connected.Expressions/Translation/Rewriters/SkipToRowNumberRewriter.cs b/Connected.Expressions/Translation/Rewriters/SkipToRowNumberRewriter.cs new file mode 100644 index 0000000..c2f541a --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/SkipToRowNumberRewriter.cs @@ -0,0 +1,62 @@ +using Connected.Expressions.Expressions; +using Connected.Expressions.Languages; +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public sealed class SkipToRowNumberRewriter : DatabaseVisitor +{ + private SkipToRowNumberRewriter(QueryLanguage language) + { + Language = language; + } + + private QueryLanguage Language { get; set; } + + public static Expression Rewrite(QueryLanguage language, Expression? expression) + { + if (new SkipToRowNumberRewriter(language).Visit(expression) is not Expression skipToRowExpression) + throw new NullReferenceException(nameof(skipToRowExpression)); + + return skipToRowExpression; + } + + protected override Expression VisitSelect(SelectExpression expression) + { + expression = (SelectExpression)base.VisitSelect(expression); + + if (expression.Skip is not null) + { + var newSelect = expression.SetSkip(null).SetTake(null); + var canAddColumn = !expression.IsDistinct && (expression.GroupBy is null || !expression.GroupBy.Any()); + + if (!canAddColumn) + newSelect = newSelect.AddRedundantSelect(Language, Alias.New()); + + var colType = Language.TypeSystem.ResolveColumnType(typeof(int)); + + newSelect = newSelect.AddColumn(new ColumnDeclaration("_rownum", new RowNumberExpression(expression.OrderBy), colType)); + newSelect = newSelect.AddRedundantSelect(Language, Alias.New()); + newSelect = newSelect.RemoveColumn(newSelect.Columns.Single(c => c.Name == "_rownum")); + + var newAlias = ((SelectExpression)newSelect.From).Alias; + var rnCol = new ColumnExpression(typeof(int), colType, newAlias, "_rownum"); + + Expression where; + + if (expression.Take is not null) + where = new BetweenExpression(rnCol, Expression.Add(expression.Skip, Expression.Constant(1)), Expression.Add(expression.Skip, expression.Take)); + else + where = rnCol.GreaterThan(expression.Skip); + + if (newSelect.Where != null) + where = newSelect.Where.And(where); + + newSelect = newSelect.SetWhere(where); + expression = newSelect; + } + + return expression; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs b/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs new file mode 100644 index 0000000..8f33eed --- /dev/null +++ b/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs @@ -0,0 +1,24 @@ +using Connected.Expressions.Visitors; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation.Rewriters; + +public class WhereClauseRewriter : DatabaseVisitor +{ + private WhereClauseRewriter(ExpressionCompilationContext context) + { + Context = context; + } + + public ExpressionCompilationContext Context { get; } + + public static Expression Rewrite(ExpressionCompilationContext context, Expression expression) + { + return new WhereClauseRewriter(context).Visit(expression); + } + + protected override Expression VisitWhere(Expression whereExpression) + { + return ParameterRewriter.Rewrite(Context, whereExpression); + } +} diff --git a/Connected.Expressions/Translation/Translator.cs b/Connected.Expressions/Translation/Translator.cs new file mode 100644 index 0000000..7f2d6f7 --- /dev/null +++ b/Connected.Expressions/Translation/Translator.cs @@ -0,0 +1,68 @@ +using Connected.Expressions.Evaluation; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation.Optimization; +using Connected.Expressions.Translation.Rewriters; +using System.Linq.Expressions; + +namespace Connected.Expressions.Translation; + +public class Translator +{ + + /// + /// Constructs a new . + /// + public Translator(ExpressionCompilationContext context) + { + Context = context; + Linguist = Context.Language.CreateLinguist(context, this); + } + + public Linguist Linguist { get; } + public ExpressionCompilationContext Context { get; } + + /// + /// Translates a query expression using rules defined by the , and . + /// + public Expression? Translate(Expression expression) + { + var result = expression; + /* + * pre-evaluate local sub-trees + */ + result = PartialEvaluator.Eval(Context, result); + /* + * apply mapping (binds LINQ operators too) + */ + result = Bind(Context, result); + /* + * any language specific translations or validations + */ + return Linguist.Translate(result); + } + + private Expression Bind(ExpressionCompilationContext context, Expression expression) + { + var bound = Binder.Bind(context, expression); + var aggmoved = AggregateRewriter.Rewrite(context, bound); + var reduced = UnusedColumns.Remove(aggmoved); + + reduced = RedundantColumns.Remove(reduced); + reduced = RedundantSubqueries.Remove(reduced); + reduced = RedundantJoins.Remove(reduced); + + var rbound = RelationshipBinder.Bind(context, reduced); + + if (rbound != reduced) + { + rbound = RedundantColumns.Remove(rbound); + rbound = RedundantJoins.Remove(rbound); + } + + var result = ComparisonRewriter.Rewrite(rbound); + + //result = WhereClauseRewriter.Rewrite(context, result); + + return result; + } +} \ No newline at end of file diff --git a/Connected.Expressions/Translation/TypeValue.cs b/Connected.Expressions/Translation/TypeValue.cs new file mode 100644 index 0000000..0f2e442 --- /dev/null +++ b/Connected.Expressions/Translation/TypeValue.cs @@ -0,0 +1,33 @@ +namespace Connected.Expressions.Translation; + +internal readonly struct TypeValue : IEquatable +{ + private readonly Type _type; + private readonly object _value; + private readonly int _hash; + + public TypeValue(Type type, object value) + { + _type = type; + _value = value; + _hash = type.GetHashCode() + (value is not null ? value.GetHashCode() : 0); + } + + public override bool Equals(object? obj) + { + if (obj is not TypeValue) + return false; + + return Equals((TypeValue)obj); + } + + public bool Equals(TypeValue vt) + { + return vt._type == _type && Equals(vt._value, _value); + } + + public override int GetHashCode() + { + return _hash; + } +} \ No newline at end of file diff --git a/Connected.Expressions/TypeSystem/QueryTypeSystem.cs b/Connected.Expressions/TypeSystem/QueryTypeSystem.cs new file mode 100644 index 0000000..454ea07 --- /dev/null +++ b/Connected.Expressions/TypeSystem/QueryTypeSystem.cs @@ -0,0 +1,12 @@ +using Connected.Expressions.Languages; + +namespace Connected.Expressions.TypeSystem; + +public abstract class QueryTypeSystem +{ + public abstract DataType Parse(string typeDeclaration); + + public abstract DataType ResolveColumnType(Type type); + + public abstract string Format(DataType type, bool suppressSize); +} diff --git a/Connected.Expressions/Visitors/DatabaseVisitor.cs b/Connected.Expressions/Visitors/DatabaseVisitor.cs new file mode 100644 index 0000000..23f4068 --- /dev/null +++ b/Connected.Expressions/Visitors/DatabaseVisitor.cs @@ -0,0 +1,554 @@ +using Connected.Expressions.Translation; +using System.Collections.ObjectModel; +using System.Linq.Expressions; + +namespace Connected.Expressions.Visitors; + +public abstract class DatabaseVisitor : ExpressionVisitor +{ + protected override Expression? Visit(Expression? expression) + { + if (expression is null) + return default; + + return (DatabaseExpressionType)expression.NodeType switch + { + DatabaseExpressionType.Table => VisitTable((TableExpression)expression), + DatabaseExpressionType.Column => VisitColumn((ColumnExpression)expression), + DatabaseExpressionType.Select => VisitSelect((SelectExpression)expression), + DatabaseExpressionType.Join => VisitJoin((JoinExpression)expression), + DatabaseExpressionType.OuterJoined => VisitOuterJoined((OuterJoinedExpression)expression), + DatabaseExpressionType.Aggregate => VisitAggregate((AggregateExpression)expression), + DatabaseExpressionType.Scalar or DatabaseExpressionType.Exists or DatabaseExpressionType.In => VisitSubquery((SubqueryExpression)expression), + DatabaseExpressionType.AggregateSubquery => VisitAggregateSubquery((AggregateSubqueryExpression)expression), + DatabaseExpressionType.IsNull => VisitIsNull((IsNullExpression)expression), + DatabaseExpressionType.Between => VisitBetween((BetweenExpression)expression), + DatabaseExpressionType.RowCount => VisitRowNumber((RowNumberExpression)expression), + DatabaseExpressionType.Projection => VisitProjection((ProjectionExpression)expression), + DatabaseExpressionType.NamedValue => VisitNamedValue((NamedValueExpression)expression), + DatabaseExpressionType.ClientJoin => VisitClientJoin((ClientJoinExpression)expression), + DatabaseExpressionType.If or DatabaseExpressionType.Block or DatabaseExpressionType.Declaration => VisitCommand((CommandExpression)expression), + DatabaseExpressionType.Batch => VisitBatch((BatchExpression)expression), + DatabaseExpressionType.Variable => VisitVariable((VariableExpression)expression), + DatabaseExpressionType.Function => VisitFunction((FunctionExpression)expression), + DatabaseExpressionType.Entity => VisitEntity((EntityExpression)expression), + _ => base.Visit(expression), + }; + } + + protected virtual Expression VisitEntity(EntityExpression entity) + { + if (Visit(entity.Expression) is not Expression entityExpression) + throw new NullReferenceException(nameof(entityExpression)); + + return UpdateEntity(entity, entityExpression); + } + + protected static EntityExpression UpdateEntity(EntityExpression entity, Expression expression) + { + if (expression != entity.Expression) + return new EntityExpression(entity.EntityType, expression); + + return entity; + } + + protected virtual Expression VisitTable(TableExpression expression) + { + return expression; + } + + protected virtual Expression VisitColumn(ColumnExpression expression) + { + return expression; + } + + protected virtual Expression VisitSelect(SelectExpression expression) + { + var from = VisitSource(expression.From); + var where = VisitWhere(expression.Where); + var groupBy = VisitExpressionList(expression.GroupBy); + var skip = Visit(expression.Skip); + var take = Visit(expression.Take); + var columns = VisitColumnDeclarations(expression.Columns); + var orderBy = VisitOrderBy(expression.OrderBy); + + return UpdateSelect(expression, from, where, orderBy, groupBy, skip, take, expression.IsDistinct, expression.IsReverse, columns); + } + + protected virtual Expression VisitWhere(Expression whereExpression) + { + return whereExpression; + } + protected static SelectExpression UpdateSelect(SelectExpression expression, Expression from, Expression? where, + IEnumerable? orderBy, IEnumerable groupBy, Expression? skip, Expression? take, + bool isDistinct, bool isReverse, IEnumerable columns) + { + if (from != expression.From || where != expression.Where || orderBy != expression.OrderBy || groupBy != expression.GroupBy + || take != expression.Take || skip != expression.Skip || isDistinct != expression.IsDistinct + || columns != expression.Columns || isReverse != expression.IsReverse) + { + return new SelectExpression(expression.Alias, columns, from, where, orderBy, groupBy, isDistinct, skip, take, isReverse); + } + + return expression; + } + + protected virtual Expression VisitJoin(JoinExpression expression) + { + if (Visit(expression.Condition) is not Expression condition) + throw new NullReferenceException(nameof(condition)); + + return UpdateJoin(expression, expression.Join, VisitSource(expression.Left), VisitSource(expression.Right), condition); + } + + protected static JoinExpression UpdateJoin(JoinExpression expression, JoinType joinType, Expression left, Expression right, Expression condition) + { + if (joinType != expression.Join || left != expression.Left || right != expression.Right || condition != expression.Condition) + return new JoinExpression(joinType, left, right, condition); + + return expression; + } + + protected virtual Expression VisitOuterJoined(OuterJoinedExpression expression) + { + if (Visit(expression.Test) is not Expression joinTest) + throw new NullReferenceException(nameof(joinTest)); + + if (Visit(expression.Expression) is not Expression joinExpression) + throw new NullReferenceException(nameof(JoinExpression)); + + return UpdateOuterJoined(expression, joinTest, joinExpression); + } + + protected static OuterJoinedExpression UpdateOuterJoined(OuterJoinedExpression expression, Expression test, Expression e) + { + if (test != expression.Test || e != expression.Expression) + return new OuterJoinedExpression(test, e); + + return expression; + } + + protected virtual Expression VisitAggregate(AggregateExpression expression) + { + if (Visit(expression.Argument) is not Expression argumentExpression) + throw new NullReferenceException(nameof(argumentExpression)); + + return UpdateAggregate(expression, expression.Type, expression.AggregateName, argumentExpression, expression.IsDistinct); + } + + protected static AggregateExpression UpdateAggregate(AggregateExpression expression, Type type, string aggType, Expression e, bool isDistinct) + { + if (type != expression.Type || aggType != expression.AggregateName || e != expression.Argument || isDistinct != expression.IsDistinct) + return new AggregateExpression(type, aggType, e, isDistinct); + + return expression; + } + + protected virtual Expression VisitIsNull(IsNullExpression expression) + { + if (Visit(expression.Expression) is not Expression nullExpression) + throw new NullReferenceException(nameof(nullExpression)); + + return UpdateIsNull(expression, nullExpression); + } + + protected static IsNullExpression UpdateIsNull(IsNullExpression expression, Expression e) + { + if (e != expression.Expression) + return new IsNullExpression(e); + + return expression; + } + + protected virtual Expression VisitBetween(BetweenExpression expression) + { + if (Visit(expression.Expression) is not Expression betweenExpression) + throw new NullReferenceException(nameof(betweenExpression)); + + if (Visit(expression.Lower) is not Expression lowerExpression) + throw new NullReferenceException(nameof(lowerExpression)); + + if (Visit(expression.Upper) is not Expression upperExpression) + throw new NullReferenceException(nameof(upperExpression)); + + return UpdateBetween(expression, betweenExpression, lowerExpression, upperExpression); + } + + protected static BetweenExpression UpdateBetween(BetweenExpression expression, Expression e, Expression lower, Expression upper) + { + if (e != expression.Expression || lower != expression.Lower || upper != expression.Upper) + return new BetweenExpression(e, lower, upper); + + return expression; + } + + protected virtual Expression VisitRowNumber(RowNumberExpression expression) + { + return UpdateRowNumber(expression, VisitOrderBy(expression.OrderBy)); + } + + protected static RowNumberExpression UpdateRowNumber(RowNumberExpression expression, IEnumerable? orderBy) + { + if (orderBy != expression.OrderBy) + { + if (orderBy is null) + throw new ArgumentNullException(nameof(orderBy)); + + return new RowNumberExpression(orderBy); + } + + return expression; + } + + protected virtual Expression VisitNamedValue(NamedValueExpression expression) + { + return expression; + } + + protected virtual Expression VisitSubquery(SubqueryExpression expression) + { + return (DatabaseExpressionType)expression.NodeType switch + { + DatabaseExpressionType.Scalar => VisitScalar((ScalarExpression)expression), + DatabaseExpressionType.Exists => VisitExists((ExistsExpression)expression), + DatabaseExpressionType.In => VisitIn((InExpression)expression), + _ => expression, + }; + } + + protected virtual Expression VisitScalar(ScalarExpression expression) + { + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + return UpdateScalar(expression, selectExpression); + } + + protected static ScalarExpression UpdateScalar(ScalarExpression expression, SelectExpression select) + { + if (select != expression.Select) + return new ScalarExpression(expression.Type, select); + + return expression; + } + + protected virtual Expression VisitExists(ExistsExpression expression) + { + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + return UpdateExists(expression, selectExpression); + } + + protected static ExistsExpression UpdateExists(ExistsExpression expression, SelectExpression select) + { + if (select != expression.Select) + return new ExistsExpression(select); + + return expression; + } + + protected virtual Expression VisitIn(InExpression expression) + { + if (Visit(expression.Expression) is not Expression inExpression) + throw new NullReferenceException(nameof(inExpression)); + + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + return UpdateIn(expression, inExpression, selectExpression, VisitExpressionList(expression.Values)); + } + + protected static InExpression UpdateIn(InExpression expression, Expression e, SelectExpression select, IEnumerable values) + { + if (e != expression.Expression || select != expression.Select || values != expression.Values) + { + if (select is not null) + return new InExpression(e, select); + else + return new InExpression(e, values); + } + + return expression; + } + + protected virtual Expression VisitAggregateSubquery(AggregateSubqueryExpression expression) + { + if (Visit(expression.AggregateAsSubquery) is not ScalarExpression scalarExpression) + throw new NullReferenceException(nameof(scalarExpression)); + + return UpdateAggregateSubquery(expression, scalarExpression); + } + + protected static AggregateSubqueryExpression UpdateAggregateSubquery(AggregateSubqueryExpression expression, ScalarExpression subquery) + { + if (subquery != expression.AggregateAsSubquery) + return new AggregateSubqueryExpression(expression.GroupByAlias, expression.AggregateInGroupSelect, subquery); + + return expression; + } + + protected virtual Expression VisitSource(Expression expression) + { + if (Visit(expression) is not Expression sourceExpression) + throw new NullReferenceException(nameof(sourceExpression)); + + return sourceExpression; + } + + protected virtual Expression VisitProjection(ProjectionExpression expression) + { + if (Visit(expression.Select) is not SelectExpression selectExpression) + throw new NullReferenceException(nameof(selectExpression)); + + if (Visit(expression.Projector) is not Expression projectorExpression) + throw new NullReferenceException(nameof(projectorExpression)); + + return UpdateProjection(expression, selectExpression, projectorExpression, expression.Aggregator); + } + + protected static ProjectionExpression UpdateProjection(ProjectionExpression expression, SelectExpression select, Expression projector, LambdaExpression? aggregator) + { + if (select != expression.Select || projector != expression.Projector || aggregator != expression.Aggregator) + return new ProjectionExpression(select, projector, aggregator); + + return expression; + } + + protected virtual Expression VisitClientJoin(ClientJoinExpression expression) + { + if (Visit(expression.Projection) is not ProjectionExpression projectionExpression) + throw new NullReferenceException(nameof(projectionExpression)); + + return UpdateClientJoin(expression, projectionExpression, VisitExpressionList(expression.OuterKey), VisitExpressionList(expression.InnerKey)); + } + + protected static ClientJoinExpression UpdateClientJoin(ClientJoinExpression expression, ProjectionExpression projection, IEnumerable outerKey, IEnumerable innerKey) + { + if (projection != expression.Projection || outerKey != expression.OuterKey || innerKey != expression.InnerKey) + return new ClientJoinExpression(projection, outerKey, innerKey); + + return expression; + } + + protected virtual Expression VisitCommand(CommandExpression expression) + { + switch ((DatabaseExpressionType)expression.NodeType) + { + case DatabaseExpressionType.If: + return VisitIf((IfCommandExpression)expression); + case DatabaseExpressionType.Block: + return VisitBlock((BlockExpression)expression); + case DatabaseExpressionType.Declaration: + return VisitDeclaration((DeclarationExpression)expression); + default: + if (VisitUnknown(expression) is not Expression unknownExpression) + throw new NullReferenceException(nameof(unknownExpression)); + + return unknownExpression; + } + } + + protected virtual Expression VisitBatch(BatchExpression expression) + { + if (Visit(expression.Operation) is not LambdaExpression lambdaExpression) + throw new NullReferenceException(nameof(lambdaExpression)); + + if (Visit(expression.BatchSize) is not Expression batchExpression) + throw new NullReferenceException(nameof(batchExpression)); + + if (Visit(expression.Stream) is not Expression streamExpression) + throw new NullReferenceException(nameof(streamExpression)); + + return UpdateBatch(expression, expression.Input, lambdaExpression, batchExpression, streamExpression); + } + + protected static BatchExpression UpdateBatch(BatchExpression expression, Expression input, LambdaExpression operation, Expression batchSize, Expression stream) + { + if (input != expression.Input || operation != expression.Operation || batchSize != expression.BatchSize || stream != expression.Stream) + return new BatchExpression(input, operation, batchSize, stream); + + return expression; + } + + protected virtual Expression VisitIf(IfCommandExpression command) + { + if (Visit(command.Check) is not Expression checkExpression) + throw new NullReferenceException(nameof(checkExpression)); + + if (Visit(command.Check) is not Expression ifTrueExpression) + throw new NullReferenceException(nameof(ifTrueExpression)); + + if (Visit(command.Check) is not Expression ifFalseExpression) + throw new NullReferenceException(nameof(ifFalseExpression)); + + return UpdateIf(command, checkExpression, ifTrueExpression, ifFalseExpression); + } + + protected static IfCommandExpression UpdateIf(IfCommandExpression command, Expression check, Expression ifTrue, Expression ifFalse) + { + if (check != command.Check || ifTrue != command.IfTrue || ifFalse != command.IfFalse) + return new IfCommandExpression(check, ifTrue, ifFalse); + + return command; + } + + protected virtual Expression VisitBlock(BlockExpression command) + { + return UpdateBlock(command, VisitExpressionList(command.Commands)); + } + + protected static BlockExpression UpdateBlock(BlockExpression command, IList commands) + { + if (command.Commands != commands) + return new BlockExpression(commands); + + return command; + } + + protected virtual Expression VisitDeclaration(DeclarationExpression command) + { + if (Visit(command.Source) is not SelectExpression sourceExpression) + throw new NullReferenceException(nameof(sourceExpression)); + + return UpdateDeclaration(command, VisitVariableDeclarations(command.Variables), sourceExpression); + } + + protected static DeclarationExpression UpdateDeclaration(DeclarationExpression command, IEnumerable variables, SelectExpression source) + { + if (variables != command.Variables || source != command.Source) + return new DeclarationExpression(variables, source); + + return command; + } + + protected virtual Expression VisitVariable(VariableExpression expression) + { + return expression; + } + + protected virtual Expression VisitFunction(FunctionExpression expression) + { + return UpdateFunction(expression, expression.Name, VisitExpressionList(expression.Arguments)); + } + + protected static FunctionExpression UpdateFunction(FunctionExpression expression, string name, IEnumerable arguments) + { + if (name != expression.Name || arguments != expression.Arguments) + return new FunctionExpression(expression.Type, name, arguments); + + return expression; + } + + protected virtual ColumnAssignment VisitColumnAssignment(ColumnAssignment column) + { + if (Visit(column.Column) is not ColumnExpression columnExpression) + throw new NullReferenceException(nameof(columnExpression)); + + if (Visit(column.Expression) is not Expression expression) + throw new NullReferenceException(nameof(expression)); + + return UpdateColumnAssignment(column, columnExpression, expression); + } + + protected static ColumnAssignment UpdateColumnAssignment(ColumnAssignment column, ColumnExpression c, Expression e) + { + if (c != column.Column || e != column.Expression) + return new ColumnAssignment(c, e); + + return column; + } + + protected virtual ReadOnlyCollection VisitColumnAssignments(ReadOnlyCollection assignments) + { + List? alternate = null; + + for (var i = 0; i < assignments.Count; i++) + { + var current = assignments[i]; + var assignment = VisitColumnAssignment(current); + + if (alternate is null && assignment != current) + alternate = assignments.Take(i).ToList(); + + alternate?.Add(assignment); + } + + if (alternate is not null) + return alternate.AsReadOnly(); + + return assignments; + } + + protected virtual ReadOnlyCollection VisitColumnDeclarations(ReadOnlyCollection columns) + { + List? alternate = null; + + for (var i = 0; i < columns.Count; i++) + { + var column = columns[i]; + + if (Visit(column.Expression) is not Expression columnDeclarationExpression) + throw new NullReferenceException(nameof(columnDeclarationExpression)); + + if (alternate is null && columnDeclarationExpression != column.Expression) + alternate = columns.Take(i).ToList(); + + alternate?.Add(new ColumnDeclaration(column.Name, columnDeclarationExpression, column.DataType)); + } + + if (alternate is not null) + return alternate.AsReadOnly(); + + return columns; + } + + protected virtual ReadOnlyCollection VisitVariableDeclarations(ReadOnlyCollection declarations) + { + List? alternate = null; + + for (var i = 0; i < declarations.Count; i++) + { + var decl = declarations[i]; + + if (Visit(decl.Expression) is not Expression declarationExpression) + throw new NullReferenceException(nameof(declarationExpression)); + + if (alternate is null && declarationExpression != decl.Expression) + alternate = declarations.Take(i).ToList(); + + alternate?.Add(new VariableDeclaration(decl.Name, decl.DataType, declarationExpression)); + } + + if (alternate is not null) + return alternate.AsReadOnly(); + + return declarations; + } + + protected virtual ReadOnlyCollection? VisitOrderBy(ReadOnlyCollection? expressions) + { + if (expressions is not null) + { + List? alternate = null; + + for (var i = 0; i < expressions.Count; i++) + { + var expr = expressions[i]; + + if (Visit(expr.Expression) is not Expression orderByExpression) + throw new NullReferenceException(nameof(orderByExpression)); + + if (alternate is null && orderByExpression != expr.Expression) + alternate = expressions.Take(i).ToList(); + + alternate?.Add(new OrderExpression(expr.OrderType, orderByExpression)); + } + + if (alternate is not null) + return alternate.AsReadOnly(); + } + + return expressions; + } +} diff --git a/Connected.Expressions/Visitors/ExpressionVisitor.cs b/Connected.Expressions/Visitors/ExpressionVisitor.cs new file mode 100644 index 0000000..ceaf491 --- /dev/null +++ b/Connected.Expressions/Visitors/ExpressionVisitor.cs @@ -0,0 +1,478 @@ +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Expressions.Visitors; + +public abstract class ExpressionVisitor : IDisposable +{ + protected bool IsDisposed { get; private set; } + + protected virtual Expression? Visit(Expression? expression) + { + if (expression is null) + return default; + + return expression.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Not or ExpressionType.Convert or ExpressionType.ConvertChecked + or ExpressionType.ArrayLength or ExpressionType.Quote or ExpressionType.TypeAs or ExpressionType.UnaryPlus => VisitUnary((UnaryExpression)expression), + ExpressionType.Add or ExpressionType.AddChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked or ExpressionType.Multiply + or ExpressionType.MultiplyChecked or ExpressionType.Divide or ExpressionType.Modulo or ExpressionType.And or ExpressionType.AndAlso + or ExpressionType.Or or ExpressionType.OrElse or ExpressionType.LessThan or ExpressionType.LessThanOrEqual or ExpressionType.GreaterThan + or ExpressionType.GreaterThanOrEqual or ExpressionType.Equal or ExpressionType.NotEqual or ExpressionType.Coalesce or ExpressionType.ArrayIndex + or ExpressionType.RightShift or ExpressionType.LeftShift or ExpressionType.ExclusiveOr or ExpressionType.Power => VisitBinary((BinaryExpression)expression), + ExpressionType.TypeIs => VisitTypeIs((TypeBinaryExpression)expression), + ExpressionType.Conditional => VisitConditional((ConditionalExpression)expression), + ExpressionType.Constant => VisitConstant((ConstantExpression)expression), + ExpressionType.Parameter => VisitParameter((ParameterExpression)expression), + ExpressionType.MemberAccess => VisitMemberAccess((MemberExpression)expression), + ExpressionType.Call => VisitMethodCall((MethodCallExpression)expression), + ExpressionType.Lambda => VisitLambda((LambdaExpression)expression), + ExpressionType.New => VisitNew((NewExpression)expression), + ExpressionType.NewArrayInit or ExpressionType.NewArrayBounds => VisitNewArray((NewArrayExpression)expression), + ExpressionType.Invoke => VisitInvocation((InvocationExpression)expression), + ExpressionType.MemberInit => VisitMemberInit((MemberInitExpression)expression), + ExpressionType.ListInit => VisitListInit((ListInitExpression)expression), + _ => VisitUnknown(expression), + }; + } + + protected virtual Expression VisitUnary(UnaryExpression expression) + { + if (Visit(expression.Operand) is not Expression visited) + throw new NullReferenceException(nameof(UnaryExpression)); + + return UpdateUnary(expression, visited, expression.Type, expression.Method); + } + + protected static UnaryExpression UpdateUnary(UnaryExpression expression, Expression operand, Type resultType, MethodInfo? method) + { + if (expression.Operand != operand || expression.Type != resultType || expression.Method != method) + return Expression.MakeUnary(expression.NodeType, operand, resultType, method); + + return expression; + } + + protected virtual Expression VisitBinary(BinaryExpression expression) + { + if (Visit(expression.Left) is not Expression left) + throw new NullReferenceException(nameof(left)); + + if (Visit(expression.Right) is not Expression right) + throw new NullReferenceException(nameof(right)); + + var conversion = Visit(expression.Conversion); + + return UpdateBinary(expression, left, right, conversion, expression.IsLiftedToNull, expression.Method); + } + + protected static BinaryExpression UpdateBinary(BinaryExpression expression, Expression left, Expression right, Expression? conversion, bool isLiftedToNull, MethodInfo? method) + { + if (left != expression.Left || right != expression.Right || conversion != expression.Conversion || method != expression.Method || isLiftedToNull != expression.IsLiftedToNull) + { + if (expression.NodeType == ExpressionType.Coalesce && expression.Conversion is not null) + return Expression.Coalesce(left, right, conversion as LambdaExpression); + else + return Expression.MakeBinary(expression.NodeType, left, right, isLiftedToNull, method); + } + + return expression; + } + + protected virtual Expression VisitTypeIs(TypeBinaryExpression expression) + { + if (Visit(expression.Expression) is not Expression visited) + throw new NullReferenceException(nameof(visited)); + + return UpdateTypeIs(expression, visited, expression.TypeOperand); + } + + protected static TypeBinaryExpression UpdateTypeIs(TypeBinaryExpression expression, Expression e, Type typeOperand) + { + if (e != expression.Expression || typeOperand != expression.TypeOperand) + return Expression.TypeIs(e, typeOperand); + + return expression; + } + + protected virtual Expression VisitConditional(ConditionalExpression expression) + { + if (Visit(expression.Test) is not Expression test) + throw new NullReferenceException(nameof(test)); + + if (Visit(expression.IfTrue) is not Expression ifTrue) + throw new NullReferenceException(nameof(ifTrue)); + + if (Visit(expression.IfFalse) is not Expression ifFalse) + throw new NullReferenceException(nameof(ifFalse)); + + return UpdateConditional(expression, test, ifTrue, ifFalse); + } + + protected static ConditionalExpression UpdateConditional(ConditionalExpression expression, Expression test, Expression ifTrue, Expression ifFalse) + { + if (test != expression.Test || ifTrue != expression.IfTrue || ifFalse != expression.IfFalse) + return Expression.Condition(test, ifTrue, ifFalse); + + return expression; + } + + protected virtual Expression VisitConstant(ConstantExpression expression) + { + return expression; + } + + protected virtual Expression VisitParameter(ParameterExpression expression) + { + return expression; + } + + protected virtual Expression VisitMemberAccess(MemberExpression expression) + { + if (Visit(expression.Expression) is not Expression member) + throw new NullReferenceException(nameof(member)); + + return UpdateMemberAccess(expression, member, expression.Member); + } + + protected static MemberExpression UpdateMemberAccess(MemberExpression expression, Expression e, MemberInfo member) + { + if (e != expression.Expression || member != expression.Member) + return Expression.MakeMemberAccess(e, member); + + return expression; + } + + protected virtual Expression? VisitMethodCall(MethodCallExpression expression) + { + return UpdateMethodCall(expression, Visit(expression.Object), expression.Method, VisitExpressionList(expression.Arguments)); + } + + protected static MethodCallExpression UpdateMethodCall(MethodCallExpression expression, Expression? e, MethodInfo method, IEnumerable args) + { + if (e != expression.Object || method != expression.Method || args != expression.Arguments) + return Expression.Call(e, method, args); + + return expression; + } + + protected virtual Expression VisitLambda(LambdaExpression lambda) + { + if (Visit(lambda.Body) is not Expression body) + throw new NullReferenceException(nameof(body)); + + return UpdateLambda(lambda, lambda.Type, body, lambda.Parameters); + } + + protected static LambdaExpression UpdateLambda(LambdaExpression lambda, Type delegateType, Expression body, IEnumerable parameters) + { + if (body != lambda.Body || parameters != lambda.Parameters || delegateType != lambda.Type) + return Expression.Lambda(delegateType, body, parameters); + + return lambda; + } + + protected virtual NewExpression VisitNew(NewExpression expression) + { + return UpdateNew(expression, expression.Constructor, VisitMemberAndExpressionList(expression.Members, expression.Arguments), expression.Members); + } + + protected static NewExpression UpdateNew(NewExpression expression, ConstructorInfo? constructor, IEnumerable args, IEnumerable? members) + { + if (args != expression.Arguments || constructor != expression.Constructor || members != expression.Members) + { + if (constructor is null) + throw new NullReferenceException(nameof(constructor)); + + if (expression.Members is not null) + return Expression.New(constructor, args, members); + else + return Expression.New(constructor, args); + } + + return expression; + } + + protected virtual Expression VisitInvocation(InvocationExpression expression) + { + if (Visit(expression.Expression) is not Expression invocation) + throw new NullReferenceException(nameof(invocation)); + + return UpdateInvocation(expression, invocation, VisitExpressionList(expression.Arguments)); + } + + protected static InvocationExpression UpdateInvocation(InvocationExpression expression, Expression e, IEnumerable args) + { + if (args != expression.Arguments || e != expression.Expression) + return Expression.Invoke(e, args); + + return expression; + } + + protected virtual Expression VisitMemberInit(MemberInitExpression expression) + { + return UpdateMemberInit(expression, VisitNew(expression.NewExpression), VisitBindingList(expression.Bindings)); + } + + protected static MemberInitExpression UpdateMemberInit(MemberInitExpression init, NewExpression e, IEnumerable bindings) + { + if (e != init.NewExpression || bindings != init.Bindings) + return Expression.MemberInit(e, bindings); + + return init; + } + + protected virtual Expression VisitListInit(ListInitExpression expression) + { + return UpdateListInit(expression, VisitNew(expression.NewExpression), VisitElementInitializerList(expression.Initializers)); + } + + protected static ListInitExpression UpdateListInit(ListInitExpression expression, NewExpression e, IEnumerable initializers) + { + if (e != expression.NewExpression || initializers != expression.Initializers) + return Expression.ListInit(e, initializers); + + return expression; + } + + protected virtual ReadOnlyCollection VisitMemberAndExpressionList(ReadOnlyCollection? members, ReadOnlyCollection? expressions) + { + if (expressions is null) + return new ReadOnlyCollection(new List()); + + List? result = null; + + for (int i = 0; i < expressions.Count; i++) + { + var current = expressions[i]; + var visited = VisitMemberAndExpression(members?[i], expressions[i]); + + if (visited is null) + continue; + + if (result is not null) + result.Add(visited); + else if (visited != current) + { + result = new List(expressions.Count); + + for (var j = 0; j < i; j++) + result.Add(expressions[j]); + + result.Add(visited); + } + } + + if (result is not null) + return result.AsReadOnly(); + + return expressions; + } + + protected virtual ReadOnlyCollection VisitExpressionList(ReadOnlyCollection? expressions) + { + if (expressions is null) + return new ReadOnlyCollection(new List()); + + List? result = null; + + for (var i = 0; i < expressions.Count; i++) + { + var current = expressions[i]; + var visited = Visit(current); + + if (visited is null) + continue; + + if (result is not null) + result.Add(visited); + else if (visited != current) + { + result = new List(expressions.Count); + + for (var j = 0; j < i; j++) + result.Add(expressions[j]); + + result.Add(visited); + } + } + + if (result is not null) + return result.AsReadOnly(); + + return expressions; + } + + protected virtual IEnumerable VisitBindingList(ReadOnlyCollection bindings) + { + List? result = null; + + for (int i = 0; i < bindings.Count; i++) + { + var current = bindings[i]; + var visited = VisitBinding(current); + + if (result is not null) + result.Add(visited); + else if (visited != current) + { + result = new List(bindings.Count); + + for (var j = 0; j < i; j++) + result.Add(bindings[j]); + + result.Add(visited); + } + } + + if (result is not null) + return result; + + return bindings; + } + + protected virtual IEnumerable VisitElementInitializerList(ReadOnlyCollection elements) + { + List? result = null; + + for (int i = 0; i < elements.Count; i++) + { + var current = elements[i]; + var visited = VisitElementInitializer(current); + + if (result is not null) + result.Add(visited); + else if (visited != current) + { + result = new List(elements.Count); + + for (var j = 0; j < i; j++) + result.Add(elements[j]); + + result.Add(visited); + } + } + + if (result is not null) + return result; + + return elements; + } + + protected virtual MemberBinding VisitBinding(MemberBinding binding) + { + return binding.BindingType switch + { + MemberBindingType.Assignment => VisitMemberAssignment((MemberAssignment)binding), + MemberBindingType.MemberBinding => VisitMemberMemberBinding((MemberMemberBinding)binding), + MemberBindingType.ListBinding => VisitMemberListBinding((MemberListBinding)binding), + _ => throw new NotSupportedException($"Unhandled binding type '{binding.BindingType}'"), + }; + } + + protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) + { + if (Visit(assignment.Expression) is not Expression assignmentExpression) + throw new NullReferenceException(nameof(assignmentExpression)); + + return UpdateMemberAssignment(assignment, assignment.Member, assignmentExpression); + } + + protected static MemberAssignment UpdateMemberAssignment(MemberAssignment assignment, MemberInfo member, Expression expression) + { + if (expression != assignment.Expression || member != assignment.Member) + return Expression.Bind(member, expression); + + return assignment; + } + + protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) + { + return UpdateMemberMemberBinding(binding, binding.Member, VisitBindingList(binding.Bindings)); + } + + protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) + { + return UpdateMemberListBinding(binding, binding.Member, VisitElementInitializerList(binding.Initializers)); + } + + protected static MemberListBinding UpdateMemberListBinding(MemberListBinding binding, MemberInfo member, IEnumerable initializers) + { + if (initializers != binding.Initializers || member != binding.Member) + return Expression.ListBind(member, initializers); + + return binding; + } + + protected static MemberMemberBinding UpdateMemberMemberBinding(MemberMemberBinding binding, MemberInfo member, IEnumerable bindings) + { + if (bindings != binding.Bindings || member != binding.Member) + return Expression.MemberBind(member, bindings); + + return binding; + } + + protected virtual Expression VisitNewArray(NewArrayExpression expression) + { + return UpdateNewArray(expression, expression.Type, VisitExpressionList(expression.Expressions)); + } + + protected static NewArrayExpression UpdateNewArray(NewArrayExpression expression, Type arrayType, IEnumerable expressions) + { + if (expressions != expression.Expressions || expression.Type != arrayType) + { + if (arrayType.GetElementType() is not Type elementType) + throw new NullReferenceException(nameof(elementType)); + + if (expression.NodeType == ExpressionType.NewArrayInit) + return Expression.NewArrayInit(elementType, expressions); + else + return Expression.NewArrayBounds(elementType, expressions); + } + + return expression; + } + + protected virtual Expression? VisitUnknown(Expression expression) + { + throw new NotSupportedException(expression.ToString()); + } + + protected virtual Expression? VisitMemberAndExpression(MemberInfo? member, Expression expression) + { + return Visit(expression); + } + + protected virtual ElementInit VisitElementInitializer(ElementInit initializer) + { + var arguments = VisitExpressionList(initializer.Arguments); + + if (arguments != initializer.Arguments) + return Expression.ElementInit(initializer.AddMethod, arguments); + + return initializer; + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + OnDisposing(); + + IsDisposed = true; + } + } + + protected virtual void OnDisposing() + { + + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Globalization/Connected.Globalization.csproj b/Connected.Globalization/Connected.Globalization.csproj new file mode 100644 index 0000000..0f0775d --- /dev/null +++ b/Connected.Globalization/Connected.Globalization.csproj @@ -0,0 +1,14 @@ + + + + net7.0 + enable + enable + + + + + + + + diff --git a/Connected.Globalization/GlobalizationRoutes.cs b/Connected.Globalization/GlobalizationRoutes.cs new file mode 100644 index 0000000..63d2085 --- /dev/null +++ b/Connected.Globalization/GlobalizationRoutes.cs @@ -0,0 +1,7 @@ +namespace Connected.Globalization +{ + public static class GlobalizationRoutes + { + public const string Languages = "/globalization/languages"; + } +} diff --git a/Connected.Globalization/IGlobalizationService.cs b/Connected.Globalization/IGlobalizationService.cs new file mode 100644 index 0000000..9637cdb --- /dev/null +++ b/Connected.Globalization/IGlobalizationService.cs @@ -0,0 +1,54 @@ +using System.Globalization; +using Connected.Globalization.Languages; + +namespace Connected.Globalization; + +/// +/// The service used for globalization related purposes. +/// +public interface IGlobalizationService +{ + /// + /// Converts an UTC date to a local one base on the current identity and timezone. + /// + /// The value to be converted. + /// a value converted to local time. + DateTime FromUtc(DateTime value); + /// + /// Converts a local date to an UTC one base on the current identity and timezone. + /// + /// The value to be converted. + /// a value converted to UTC time. + DateTime ToUtc(DateTime value); + /// + /// Converts an UTC date to a local one base on the current identity and timezone. + /// + /// The value to be converted. + /// a value converted to local time. + DateTimeOffset FromUtc(DateTimeOffset value); + /// + /// Converts a local date to an UTC one base on the current identity and timezone. + /// + /// The value to be converted. + /// a value converted to UTC time. + DateTimeOffset ToUtc(DateTimeOffset value); + /// + /// Returns currently used timezone which is based on the current identity. + /// + TimeZoneInfo TimeZone { get; } + /// + /// Returns the converted value to a currently set . + /// + DateTimeOffset Now { get; } + /// + /// Returns the currently active language based on a user identity. + /// + /// The if user has a language set, null otherwise. + Task GetCurrentLanguage(); + /// + /// Returns the currently active based on a user identity. + /// + /// The if user has a language set, + /// otherwise. + Task GetCurrentCulture(); +} diff --git a/Connected.Globalization/Languages/ILanguage.cs b/Connected.Globalization/Languages/ILanguage.cs new file mode 100644 index 0000000..313572a --- /dev/null +++ b/Connected.Globalization/Languages/ILanguage.cs @@ -0,0 +1,39 @@ +using Connected.Data; + +namespace Connected.Globalization.Languages; + +/// +/// Represents a language used by localization services in the environment. +/// +/// +/// In a multilanguage enviroments this entity defines a set of languages +/// supported by the environment. Language contains LCID which is used when +/// choosing which localization strings are loaded in the request. +/// Language is typically set by user's identity thus enabling environment +/// globalization. +/// +public interface ILanguage : IPrimaryKey +{ + /// + /// The name of the language. This is a descriptive property and does + /// not have any special meaning when mapping globalization services. + /// + string Name { get; } + /// + /// The language status. If the language is + /// ignored even user has a language set to its identity. + /// + Status Status { get; } + /// + /// The locale id associated with a language. This value is used when a language + /// is set as a request language and is + /// set during the request pipeline. + /// + int Lcid { get; } + /// + /// A set of mapping strings used when the environment is resolving a language from a browser. + /// Browser sends supported languages and if identity is not set en evrionment tries to resolve + /// a correct language via this property. Use ',' delimiter when setting multiple mappings. + /// + string Mappings { get; } +} diff --git a/Connected.Globalization/Languages/ILanguageService.cs b/Connected.Globalization/Languages/ILanguageService.cs new file mode 100644 index 0000000..34e4e2f --- /dev/null +++ b/Connected.Globalization/Languages/ILanguageService.cs @@ -0,0 +1,71 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.Notifications; +using Connected.ServiceModel; + +namespace Connected.Globalization.Languages; + +[Service] +[ServiceUrl(GlobalizationRoutes.Languages)] +public interface ILanguageService : IServiceNotifications +{ + /// + /// Returns all valid languages from the environment. + /// + /// A list of available languages. + [ServiceMethod(ServiceMethodVerbs.Get)] + Task> Query(); + /// + /// Performs a query on entities for the specified list of ids. + /// + /// The List of the ids for which the perform query. + /// An of entities that matches + /// the passed ids. + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task?> Query(PrimaryKeyListArgs e); + /// + /// Selects a language by its id. + /// + /// containing the language id for which an entity should be returned. + /// entity if found, null otherwise. + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + /// + /// Selects a language by its name. + /// + /// containing a language name for which an entity should be returned. + /// entity if found, null otherwise. + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(NameArgs args); + /// + /// Resolves a language by mapping criteria. + /// + /// containing mapping criteria. + /// + /// Mapping criteria is split into tokens, separated by ',' character. Each string (mapping) is then searched in the + /// property of each supported and language. + /// + /// The first that has at least one mapping, null if no language meets + /// the criteria. + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Resolve(LanguageResolveArgs args); + /// + /// Inserts a new language. + /// + /// arguments containing data about a new language. + /// An of the newly added language. + [ServiceMethod(ServiceMethodVerbs.Post)] + Task Insert(LanguageInsertArgs args); + /// + /// Updates existing language. + /// + /// arguments with modified values. + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Patch)] + Task Update(LanguageUpdateArgs args); + /// + /// Deletes language from the environment. + /// + /// containing an id of the language to be deleted. + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Globalization/Languages/LanguageArgs.cs b/Connected.Globalization/Languages/LanguageArgs.cs new file mode 100644 index 0000000..74f8d63 --- /dev/null +++ b/Connected.Globalization/Languages/LanguageArgs.cs @@ -0,0 +1,46 @@ +using System.ComponentModel.DataAnnotations; +using Connected; +using Connected.Notifications; + +namespace Connected.Globalization.Languages; + +/// +/// The arguments used when inserting a new . +/// +public class LanguageInsertArgs : IEventArgs +{ + /// + /// The language name. + /// + [Required] + [MaxLength(128)] + public string? Name { get; init; } + /// + /// The locale id of the language. This should point to one of the + /// supported .NET Core supported languages. + /// + public int Lcid { get; init; } + /// + /// The language mappings, separated with ',' character. + /// + public string? Mappings { get; init; } +} +/// +/// The arguments used when updating a language. +/// +public sealed class LanguageUpdateArgs : LanguageInsertArgs +{ + public int Id { get; init; } +} +/// +/// The arguments used when resolving a lnaguge by its mapping. +/// +public class LanguageResolveArgs : IDto +{ + /// + /// The mapping string whose values will be compared with a mapping + /// property of the language. Use ',' separator when using multiple + /// mappings. + /// + public string? Mapping { get; init; } +} diff --git a/Connected.Hosting/Connected.Hosting.csproj b/Connected.Hosting/Connected.Hosting.csproj new file mode 100644 index 0000000..cc5f1ee --- /dev/null +++ b/Connected.Hosting/Connected.Hosting.csproj @@ -0,0 +1,29 @@ + + + + net7.0 + enable + enable + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Hosting/SR.Designer.cs b/Connected.Hosting/SR.Designer.cs new file mode 100644 index 0000000..15ceb68 --- /dev/null +++ b/Connected.Hosting/SR.Designer.cs @@ -0,0 +1,63 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Hosting { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Hosting.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + } +} diff --git a/Connected.Hosting/SR.resx b/Connected.Hosting/SR.resx new file mode 100644 index 0000000..1af7de1 --- /dev/null +++ b/Connected.Hosting/SR.resx @@ -0,0 +1,120 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + \ No newline at end of file diff --git a/Connected.Hosting/Workers/IWorker.cs b/Connected.Hosting/Workers/IWorker.cs new file mode 100644 index 0000000..cb79a13 --- /dev/null +++ b/Connected.Hosting/Workers/IWorker.cs @@ -0,0 +1,8 @@ +using Microsoft.Extensions.Hosting; + +namespace Connected.Hosting.Workers +{ + public interface IWorker : IHostedService + { + } +} diff --git a/Connected.Hosting/Workers/ScheduledWorker.cs b/Connected.Hosting/Workers/ScheduledWorker.cs new file mode 100644 index 0000000..b1953f9 --- /dev/null +++ b/Connected.Hosting/Workers/ScheduledWorker.cs @@ -0,0 +1,33 @@ +namespace Connected.Hosting.Workers +{ + public abstract class ScheduledWorker : Worker + { + protected virtual TimeSpan Timer { get; set; } = TimeSpan.FromMinutes(1); + protected long Count { get; private set; } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + do + { + try + { + await OnInvoke(stoppingToken); + } + catch (Exception ex) + { + await OnError(ex); + } + finally + { + Count++; + } + + if (Timer == TimeSpan.Zero) + break; + + await Task.Delay(Timer, stoppingToken); + } + while (!stoppingToken.IsCancellationRequested); + } + } +} diff --git a/Connected.Hosting/Workers/Worker.cs b/Connected.Hosting/Workers/Worker.cs new file mode 100644 index 0000000..adfc472 --- /dev/null +++ b/Connected.Hosting/Workers/Worker.cs @@ -0,0 +1,25 @@ +using Microsoft.Extensions.Hosting; + +namespace Connected.Hosting.Workers +{ + /// + /// This component acts as a background worker. + /// + public abstract class Worker : BackgroundService, IWorker + { + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + await OnInvoke(stoppingToken); + } + + protected virtual async Task OnInvoke(CancellationToken cancellationToken) + { + await Task.CompletedTask; + } + + protected virtual async Task OnError(Exception ex) + { + await Task.CompletedTask; + } + } +} diff --git a/Connected.Instance/Assemblies.cs b/Connected.Instance/Assemblies.cs new file mode 100644 index 0000000..aab1b6b --- /dev/null +++ b/Connected.Instance/Assemblies.cs @@ -0,0 +1,92 @@ +using System.Collections.Immutable; +using System.Reflection; +using Connected.Annotations; +using Connected.Collections; + +namespace Connected.Instance; + +public static class Assemblies +{ + private static readonly List _all; + static Assemblies() + { + _all = new(); + + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + if (assembly.GetCustomAttribute() is not null) + _all.Add(assembly); + } + } + + internal static ImmutableList All => _all.ToImmutableList(); + public static Dictionary QueryInterfaces(this Assembly assembly) where TAttribute : Attribute + { + var result = new Dictionary(); + + foreach (var type in assembly.GetTypes()) + { + if (type.IsAbstract || !type.IsClass) + continue; + + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (i.GetCustomAttribute() is not null) + result.Add(i, type); + } + } + + return result; + } + + public static Dictionary QueryImplementations(this Assembly assembly) + { + var result = new Dictionary(); + + foreach (var type in assembly.GetTypes()) + { + if (type.IsAbstract || !type.IsClass) + continue; + + if (type.GetInterface(typeof(T).FullName) is null) + continue; + + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (i.GetInterface(typeof(T).FullName) is not null) + result.Add(i, type); + } + } + + return result; + } + + public static ImmutableList QueryImplementations() + { + var target = typeof(T).FullName; + var result = new List(); + + foreach (var microService in _all) + { + var types = microService.GetTypes(); + + foreach (var type in types) + { + if (type.IsAbstract || type.IsPrimitive || type.IsInterface) + continue; + + if (type.GetInterface(target) is not null) + result.Add(type); + } + } + + result.SortByPriority(); + + return result.ToImmutableList(); + } + +} \ No newline at end of file diff --git a/Connected.Instance/Connected.Instance.csproj b/Connected.Instance/Connected.Instance.csproj new file mode 100644 index 0000000..4a5281a --- /dev/null +++ b/Connected.Instance/Connected.Instance.csproj @@ -0,0 +1,45 @@ + + + + net7.0 + enable + enable + $(MSBuildProjectName.Replace(" ", "_")) + + + + + SR.resx + True + True + + + + + + SR.Designer.cs + ResXFileCodeGenerator + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Connected.Instance/EntitySynchronizer.cs b/Connected.Instance/EntitySynchronizer.cs new file mode 100644 index 0000000..d41c4a3 --- /dev/null +++ b/Connected.Instance/EntitySynchronizer.cs @@ -0,0 +1,112 @@ +using System.Reflection; +using Connected.Configuration.Environment; +using Connected.Data.Schema; +using Connected.Entities; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Connected.Instance; + +internal static class EntitySynchronizer +{ + public static async Task Synchronize(IServiceProvider services, string value) + { + var logger = services.GetService>(); + + if (logger is null) + throw new NullReferenceException(nameof(ILogger)); + + logger.LogTrace("Starting entities synchronization."); + + if (services.GetService() is not IContextProvider provider) + throw new NullReferenceException(nameof(IContextProvider)); + + using var ctx = provider.Create(); + + if (ctx.GetService() is not IEnvironmentService environmentService) + throw new NullReferenceException(nameof(IEnvironmentService)); + + if (ctx.GetService() is not ISchemaService schemaService) + { + logger.LogWarning("ISchemaService is not registered. Entity synchronization skipped."); + return; + } + /* + * entitySynchronization = rebuild + */ + if (string.Equals(value, "rebuild", StringComparison.OrdinalIgnoreCase)) + { + logger.LogTrace("Rebuilding entities."); + + foreach (var assembly in environmentService.MicroServices) + await Synchronize(schemaService, assembly, logger); + } + else + { + /* + * entitySynchronization = token1, token2,... + */ + var tokens = value.Split(','); + /* + * token1: assembly : assemblyName + * token2: type : typeName + */ + foreach (var token in tokens) + { + var subTokens = token.Split(':'); + + if (subTokens.Length != 2) + throw new ArgumentException("Invalid entitySyncronization token '{token}'. Expected assembly:[assembly] or type:[type].", token); + + if (string.Equals(subTokens[0], "assembly", StringComparison.OrdinalIgnoreCase)) + { + logger.LogTrace("Loading assembly '{assembly}'", subTokens[1]); + + await Synchronize(schemaService, Assembly.Load(AssemblyName.GetAssemblyName(subTokens[1])), logger); + } + else if (string.Equals(subTokens[1], "type", StringComparison.OrdinalIgnoreCase)) + { + logger.LogTrace("Loading type '{type}'", subTokens[1]); + + if (Type.GetType(subTokens[1]) is not Type type) + { + logger.LogWarning("Entity type '{type}' could not be loaded. Synchronization on the specified type could not be performed.", subTokens[1]); + continue; + } + + await Synchronize(schemaService, new List { type }); + } + } + } + + logger.LogTrace("Commiting snchronization."); + + if (ctx.GetService() is ITransactionContext transaction) + await transaction.Commit(); + + logger.LogTrace("Snchronization completed."); + } + + private static async Task Synchronize(ISchemaService schemaService, Assembly assembly, ILogger logger) + { + var entities = new List(); + + foreach (var type in assembly.GetTypes()) + { + if (type.IsAbstract || !type.IsAssignableTo(typeof(IEntity))) + continue; + + entities.Add(type); + } + + if (entities.Any()) + await Synchronize(schemaService, entities); + } + + private static async Task Synchronize(ISchemaService schemaService, List entities) + { + await schemaService.Synchronize(entities); + } +} diff --git a/Connected.Instance/Instance.cs b/Connected.Instance/Instance.cs new file mode 100644 index 0000000..8373921 --- /dev/null +++ b/Connected.Instance/Instance.cs @@ -0,0 +1,64 @@ +using Connected.Interop; +using Microsoft.AspNetCore.Builder; + +namespace Connected.Instance +{ + public static class Instance + { + internal static WebApplication Host { get; private set; } + + internal static async Task StartAsync(Dictionary args) + { + var builder = WebApplication.CreateBuilder(UnpackArguments(args)); + var startups = Assemblies.QueryImplementations(); + + foreach (var assembly in Assemblies.All) + builder.Services.AddMicroService(assembly); + + foreach (var startup in startups) + { + if (startup.CreateInstance() is IStartup start) + start.ConfigureServices(builder.Services); + } + + Host = builder.Build(); + + foreach (var startup in startups) + { + if (startup.CreateInstance() is IStartup start) + start.Configure(Host); + } + + foreach (var startup in startups) + { + if (startup.CreateInstance() is IStartup start) + await start.Initialize(args); + } + + foreach (var startup in startups) + { + if (startup.CreateInstance() is IStartup start) + await start.Start(args); + } + + if (args.TryGetValue("entitySynchronization", out string? entities)) + await EntitySynchronizer.Synchronize(Host.Services, entities); + + await Host.RunAsync(); + } + + private static string[] UnpackArguments(Dictionary args) + { + var result = new string[args.Count]; + + for (var i = 0; i < args.Count; i++) + { + var arg = args.ElementAt(i); + + result[i] = string.IsNullOrWhiteSpace(arg.Value) ? arg.Key : $"{arg.Key}={arg.Value}"; + } + + return result; + } + } +} \ No newline at end of file diff --git a/Connected.Instance/SR.Designer.cs b/Connected.Instance/SR.Designer.cs new file mode 100644 index 0000000..ff1bd50 --- /dev/null +++ b/Connected.Instance/SR.Designer.cs @@ -0,0 +1,126 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Instance { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Connected.Instance.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to The specified entity was not found. + /// + internal static string DefaultAssertNullMessage { + get { + return ResourceManager.GetString("DefaultAssertNullMessage", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot resolve database type. + /// + internal static string ErrCannotResolveDatabaseType { + get { + return ResourceManager.GetString("ErrCannotResolveDatabaseType", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot create service instance. + /// + internal static string ErrCreateService { + get { + return ResourceManager.GetString("ErrCreateService", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Type '{0}' does not implement '{1}' interface. + /// + internal static string ErrInterfaceExpected { + get { + return ResourceManager.GetString("ErrInterfaceExpected", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to PrimaryKey argument is null. + /// + internal static string ErrPrimaryKeyNull { + get { + return ResourceManager.GetString("ErrPrimaryKeyNull", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Service already registered. + /// + internal static string ErrServiceRegistered { + get { + return ResourceManager.GetString("ErrServiceRegistered", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Transaction already attached. + /// + internal static string ErrTransactionNotNull { + get { + return ResourceManager.GetString("ErrTransactionNotNull", resourceCulture); + } + } + } +} diff --git a/Connected.Instance/SR.resx b/Connected.Instance/SR.resx new file mode 100644 index 0000000..9264990 --- /dev/null +++ b/Connected.Instance/SR.resx @@ -0,0 +1,141 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + The specified entity was not found + + + Cannot resolve database type + + + Cannot create service instance + + + Type '{0}' does not implement '{1}' interface + + + PrimaryKey argument is null + + + Service already registered + + + Transaction already attached + + \ No newline at end of file diff --git a/Connected.Instance/ServerStartup.cs b/Connected.Instance/ServerStartup.cs new file mode 100644 index 0000000..8a68cec --- /dev/null +++ b/Connected.Instance/ServerStartup.cs @@ -0,0 +1,46 @@ +using Connected.Annotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Net.Http.Headers; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Instance +{ + [Priority(int.MaxValue)] + internal sealed class ServerStartup : Startup + { + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddHttpContextAccessor(); + services.AddSignalR(); + services.AddHttpClient(); + services.AddLogging(); + services.AddAntiforgery(); + + services.AddCors((o) => + { + o.AddDefaultPolicy(new Microsoft.AspNetCore.Cors.Infrastructure.CorsPolicy + { + IsOriginAllowed = (a) => { return true; }, + Headers = { HeaderNames.ContentType, HeaderNames.ContentEncoding, HeaderNames.XRequestedWith, "x-signalr-user-agent" } + }); + }); + } + + protected override void OnConfigure(WebApplication app) + { + if (!app.Environment.IsDevelopment()) + { + app.UseExceptionHandler("/Error"); + app.UseHsts(); + app.UseHttpsRedirection(); + } + + app.UseCors(); + app.UseRouting(); + } + } +} diff --git a/Connected.Instance/Start.cs b/Connected.Instance/Start.cs new file mode 100644 index 0000000..4aebec4 --- /dev/null +++ b/Connected.Instance/Start.cs @@ -0,0 +1,10 @@ +namespace Connected.Instance +{ + public static class Start + { + public static async Task ConfigureAsync(Dictionary args) + { + await Instance.StartAsync(args); + } + } +} \ No newline at end of file diff --git a/Connected.Instance/StartupExtensions.cs b/Connected.Instance/StartupExtensions.cs new file mode 100644 index 0000000..344c0c4 --- /dev/null +++ b/Connected.Instance/StartupExtensions.cs @@ -0,0 +1,194 @@ +using Connected.Annotations; +using Connected.Collections.Concurrent; +using Connected.Configuration; +using Connected.Entities.Caching; +using Connected.Hosting.Workers; +using Connected.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using System.Reflection; + +namespace Connected.Instance; + +public static class StartupExtensions +{ + internal static IServiceCollection AddMicroService(this IServiceCollection services, Assembly assembly) + { + foreach (var type in assembly.GetTypes()) + { + if (type.IsAbstract || !type.IsClass) + continue; + + AddService(type, services, false); + AddArgument(type, services, false); + AddServiceOperation(type, services, false); + AddEntityCache(type, services, false); + AddMiddleware(type, services, false); + AddDispatcher(type, services, false); + AddDispatcherJob(type, services, false); + AddHostedService(type, services, false); + } + + return services; + } + + public static void AddServiceOperation(Type type, IServiceCollection services) + { + AddServiceOperation(type, services, true); + } + + private static void AddServiceOperation(Type type, IServiceCollection services, bool manual) + { + if (CanRegister(type, manual) && type.IsServiceOperation()) + { + services.AddTransient(type); + RegisteredServices.AddApi(type); + } + } + + public static void AddService(Type type, IServiceCollection services) + { + AddService(type, services, true); + } + + private static void AddService(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual)) + return; + + if (type.GetInterface(typeof(IService).FullName) is null) + return; + + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (i.GetCustomAttribute() is not null) + { + services.Add(ServiceDescriptor.Scoped(i, type)); + RegisteredServices.AddApiService(type); + } + } + } + + public static void AddArgument(Type type, IServiceCollection services) + { + AddArgument(type, services, true); + } + + private static void AddArgument(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || type.GetInterface(typeof(IDto).FullName) is null) + return; + + services.Add(ServiceDescriptor.Transient(type, type)); + RegisteredServices.AddArgument(type); + } + + public static void AddMiddleware(Type type, IServiceCollection services) + { + AddMiddleware(type, services, true); + } + public static void AddMiddleware(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || type.GetInterface(typeof(IMiddleware).FullName) is null) + return; + + var att = type.GetCustomAttribute(); + var scope = ServiceRegistrationScope.Scoped; + + if (att is not null) + scope = att.Scope; + + switch (scope) + { + case ServiceRegistrationScope.Singleton: + services.Add(ServiceDescriptor.Singleton(type, type)); + break; + case ServiceRegistrationScope.Scoped: + services.Add(ServiceDescriptor.Scoped(type, type)); + break; + case ServiceRegistrationScope.Transient: + services.Add(ServiceDescriptor.Transient(type, type)); + break; + } + + RegisteredServices.AddMiddleware(type); + } + + public static void AddEntityCache(Type type, IServiceCollection services) + { + AddEntityCache(type, services, true); + } + + private static void AddEntityCache(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || typeof(IEntityCacheClient<,>).FullName is not string fullName) + return; + + if (type.GetInterface(fullName) is null) + return; + + foreach (var itf in type.GetInterfaces()) + { + if (itf.GetInterface(fullName) is not null) + { + services.Add(ServiceDescriptor.Singleton(itf, type)); + RegisteredServices.AddEntityCache(type); + } + } + } + + public static void AddDispatcher(Type type, IServiceCollection services) + { + AddDispatcher(type, services, true); + } + + private static void AddDispatcher(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || typeof(IDispatcher<,>).FullName is not string fullName) + return; + + if (type.GetInterface(fullName) is null) + return; + + services.AddTransient(type); + } + + public static void AddDispatcherJob(Type type, IServiceCollection services) + { + AddDispatcherJob(type, services, true); + } + + private static void AddDispatcherJob(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || typeof(IDispatcherJob<>).FullName is not string fullName) + return; + + if (type.GetInterface(fullName) is null) + return; + + services.AddTransient(type); + } + + public static void AddHostedService(Type type, IServiceCollection services) + { + AddHostedService(type, services, true); + } + + private static void AddHostedService(Type type, IServiceCollection services, bool manual) + { + if (!CanRegister(type, manual) || type.GetInterface(typeof(IWorker).FullName) is null) + return; + + services.AddSingleton(typeof(IHostedService), type); + } + + private static bool CanRegister(Type type, bool manual) + { + if (manual) + return true; + + return type.GetCustomAttribute() is not ServiceRegistrationAttribute att || att.Mode == ServiceRegistrationMode.Auto; + } +} diff --git a/Connected.Interop/Annotations/ArgsBindingAttribute.cs b/Connected.Interop/Annotations/ArgsBindingAttribute.cs new file mode 100644 index 0000000..e2d3cae --- /dev/null +++ b/Connected.Interop/Annotations/ArgsBindingAttribute.cs @@ -0,0 +1,8 @@ +using Connected; + +namespace Connected.Interop.Annotations; + +[AttributeUsage(AttributeTargets.Parameter)] +public sealed class ArgsBindingAttribute : Attribute where T : IDto +{ +} diff --git a/Connected.Interop/Connected.Interop.csproj b/Connected.Interop/Connected.Interop.csproj new file mode 100644 index 0000000..4a05ca1 --- /dev/null +++ b/Connected.Interop/Connected.Interop.csproj @@ -0,0 +1,29 @@ + + + + net7.0 + enable + enable + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Interop/Enumerables.cs b/Connected.Interop/Enumerables.cs new file mode 100644 index 0000000..10598d4 --- /dev/null +++ b/Connected.Interop/Enumerables.cs @@ -0,0 +1,72 @@ +using System.Reflection; + +namespace Connected.Interop +{ + public static class Enumerables + { + public static Type? FindEnumerable(this Type type) + { + if (type is null || type == typeof(string)) + return default; + + if (type.IsArray) + { + var elementType = type.GetElementType(); + + if (elementType is not null) + return typeof(IEnumerable<>).MakeGenericType(elementType); + else + return default; + } + + var typeInfo = type.GetTypeInfo(); + + if (typeInfo.IsGenericType) + { + foreach (var arg in typeInfo.GenericTypeArguments) + { + var en = typeof(IEnumerable<>).MakeGenericType(arg); + + if (en.GetTypeInfo().IsAssignableFrom(typeInfo)) + return en; + } + } + + foreach (var itf in typeInfo.ImplementedInterfaces) + { + var en = itf.FindEnumerable(); + + if (en is not null) + return en; + } + + if (typeInfo.BaseType is not null && typeInfo.BaseType != typeof(object)) + return typeInfo.BaseType.FindEnumerable(); + + return default; + } + + public static bool IsEnumerable(this Type type) + { + return type.FindEnumerable() is not null; + } + + public static Type GetEnumerableType(this Type elementType) + { + return typeof(IEnumerable<>).MakeGenericType(elementType); + } + + public static Type? GetEnumerableElementType(this Type? enumerableType) + { + if (enumerableType is null) + return default; + + var en = enumerableType.FindEnumerable(); + + if (en is null) + return enumerableType; + + return en.GetTypeInfo().GenericTypeArguments[0]; + } + } +} diff --git a/Connected.Interop/Generics.cs b/Connected.Interop/Generics.cs new file mode 100644 index 0000000..0b8d41b --- /dev/null +++ b/Connected.Interop/Generics.cs @@ -0,0 +1,22 @@ +namespace Connected.Interop +{ + public static class Generics + { + public static bool IsSubclassOfGenericType(this Type type, Type genericType) + { + var current = type.BaseType; + + while (current is not null && current != typeof(object)) + { + var currentType = current.IsGenericType ? current.GetGenericTypeDefinition() : current; + + if (genericType == currentType) + return true; + + current = current.BaseType; + } + + return false; + } + } +} diff --git a/Connected.Interop/InteropExtensions.cs b/Connected.Interop/InteropExtensions.cs new file mode 100644 index 0000000..64ce8f5 --- /dev/null +++ b/Connected.Interop/InteropExtensions.cs @@ -0,0 +1,27 @@ +using Connected.Notifications; + +namespace Connected.Interop; + +public static class InteropExtensions +{ + public static TArgs AsArguments(this IDto args) where TArgs : IDto + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, args); + } + + public static TArgs AsArguments(this IDto args, params object[] sources) where TArgs : IDto + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, sources.ToArray(), args); + } + + public static TArgs AsEventArguments(this IDto args) where TArgs : IEventArgs + { + var instance = typeof(TArgs).CreateInstance(); + + return Serializer.Merge(instance, args); + } +} diff --git a/Connected.Interop/InteropStartup.cs b/Connected.Interop/InteropStartup.cs new file mode 100644 index 0000000..b6c8a02 --- /dev/null +++ b/Connected.Interop/InteropStartup.cs @@ -0,0 +1,16 @@ +using Connected.Annotations; +using Microsoft.AspNetCore.Builder; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Interop; + +internal class InteropStartup : Startup +{ + public static WebApplication? Application { get; private set; } + + protected override void OnConfigure(WebApplication app) + { + Application = app; + } +} diff --git a/Connected.Interop/Members.cs b/Connected.Interop/Members.cs new file mode 100644 index 0000000..6a1866d --- /dev/null +++ b/Connected.Interop/Members.cs @@ -0,0 +1,57 @@ +using System.Reflection; + +namespace Connected.Interop +{ + public static class Members + { + public static IEnumerable GetDataMembers(this Type type, string? name = null, bool includeNonPublic = false) + { + return type.GetInheritedProperites() + .Where(p => p.CanRead && p.GetMethod is not null && !p.GetMethod.IsStatic && (p.GetMethod.IsPublic || includeNonPublic) && (string.IsNullOrEmpty(name) || string.Equals(p.Name, name, StringComparison.Ordinal))) + .Cast().Concat(type.GetInheritedFields() + .Where(f => !f.IsStatic && (f.IsPublic || includeNonPublic) && (string.IsNullOrEmpty(name) || string.Equals(f.Name, name, StringComparison.Ordinal)))); + } + + public static MemberInfo? GetDataMember(this Type type, string name, bool includeNonPublic = false) + { + return type.GetDataMembers(name, includeNonPublic).FirstOrDefault(); + } + + public static IEnumerable GetInheritedFields(this Type type) + { + foreach (var info in type.GetInheritedTypeInfos()) + { + foreach (var p in info.GetRuntimeFields()) + yield return p; + } + } + + public static Type? GetMemberType(MemberInfo mi) + { + if (mi is FieldInfo fi) + return fi.FieldType; + + if (mi is PropertyInfo pi) + return pi.PropertyType; + + if (mi is EventInfo ei) + return ei.EventHandlerType; + + if (mi is MethodInfo me) + return me.ReturnType; + + return default; + } + + public static bool IsReadOnly(MemberInfo member) + { + if (member is PropertyInfo pi) + return !pi.CanWrite || pi.SetMethod is null; + + if (member is FieldInfo fi) + return (fi.Attributes & FieldAttributes.InitOnly) != 0; + + return true; + } + } +} diff --git a/Connected.Interop/Merging/JsonMerger.cs b/Connected.Interop/Merging/JsonMerger.cs new file mode 100644 index 0000000..51899d0 --- /dev/null +++ b/Connected.Interop/Merging/JsonMerger.cs @@ -0,0 +1,100 @@ +using System.Collections; +using System.Reflection; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Connected.Interop.Merging +{ + internal sealed class JsonMerger : Merger + { + public void Merge(object destination, JsonNode? source) + { + if (source is null || destination is null) + return; + + foreach (var property in Properties.GetImplementedProperties(destination)) + MergeProperty(destination, source, property); + } + + private void MergeProperty(object destination, JsonNode source, PropertyInfo property) + { + if (property.FindAttribute() is not null) + return; + + if (property.PropertyType.IsTypePrimitive()) + { + if (!property.CanWrite || source is not JsonObject jo || ResolveJsonProperty(property, jo) is not JsonValue jprop) + return; + + if (!TypeConversion.TryConvert(jprop.ToString(), out object? convertedValue, property.PropertyType)) + return; + + property.SetValue(destination, convertedValue); + } + else if (IsArray(property)) + MergeEnumerable(destination, source, property); + else + MergeObject(destination, source, property); + } + + private static JsonValue? ResolveJsonProperty(PropertyInfo property, JsonObject json) + { + foreach (var prop in json) + { + if (string.Equals(prop.Key, property.Name, StringComparison.OrdinalIgnoreCase)) + return prop.Value as JsonValue; + } + + return null; + } + + private void MergeEnumerable(object destination, JsonNode source, PropertyInfo property) + { + if (source is not JsonObject jobject) + return; + + if (!jobject.ContainsKey(property.Name) || jobject[property.Name] is not JsonArray array || !array.Any()) + return; + + var value = property.GetValue(destination); + + if (value is null && !property.CanWrite) + return; + + var addMethod = property.PropertyType.GetMethod(nameof(IList.Add), BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic); + var instance = addMethod is not null ? Activator.CreateInstance(property.PropertyType) : Array.CreateInstance(property.PropertyType, array.Count); + var elementType = instance.GetType().GetElementType(); + + for (var i = 0; i < array.Count; i++) + { + var json = array[i]; + var item = Activator.CreateInstance(elementType); + + Merge(item, json); + + if (addMethod is not null) + addMethod.Invoke(instance, new object[] { item }); + else + ((Array)instance).SetValue(item, i); + } + + property.SetValue(destination, instance); + } + + private void MergeObject(object destination, JsonNode source, PropertyInfo property) + { + var value = property.GetValue(destination); + + if (value is null && property.CanWrite) + property.SetValue(destination, property.PropertyType.CreateInstance()); + + value = property.GetValue(destination); + + if (value is null) + return; + + foreach (var instanceProperty in Properties.GetImplementedProperties(value)) + MergeProperty(value, source, instanceProperty); + } + } +} diff --git a/Connected.Interop/Merging/Merger.cs b/Connected.Interop/Merging/Merger.cs new file mode 100644 index 0000000..137eca8 --- /dev/null +++ b/Connected.Interop/Merging/Merger.cs @@ -0,0 +1,24 @@ +using System.Collections; +using System.Reflection; + +namespace Connected.Interop.Merging +{ + internal abstract class Merger + { + protected static bool IsArray(PropertyInfo property) + { + return property.PropertyType.IsArray || property.PropertyType.GetInterface(typeof(IEnumerable).FullName) is not null; + } + + protected static List GetElements(IEnumerable enumerator) + { + var result = new List(); + var en = enumerator.GetEnumerator(); + + while (en.MoveNext()) + result.Add(en.Current); + + return result; + } + } +} diff --git a/Connected.Interop/Merging/ObjectMerger.cs b/Connected.Interop/Merging/ObjectMerger.cs new file mode 100644 index 0000000..c00017b --- /dev/null +++ b/Connected.Interop/Merging/ObjectMerger.cs @@ -0,0 +1,127 @@ +using Connected.ServiceModel; +using System.Collections; +using System.Reflection; + +namespace Connected.Interop.Merging +{ + internal sealed class ObjectMerger : Merger + { + public void Merge(object destination, params object[] sources) + { + if (destination is null || !HasSource(sources)) + return; + + var sourceProperties = AggregateProperties(sources); + + foreach (var property in Properties.GetImplementedProperties(destination)) + MergeProperty(destination, sourceProperties, property); + } + + private static bool HasSource(params object[] sources) + { + foreach (var source in sources) + { + if (source is not null) + return true; + } + + return false; + } + + private static Dictionary AggregateProperties(params object[] sources) + { + var result = new Dictionary(); + + for (var i = sources.Length - 1; i >= 0; i--) + { + var source = sources[i]; + + if (source is null) + continue; + + var props = Properties.GetImplementedProperties(source); + + foreach (var property in props) + { + if (result.ContainsKey(property.Name)) + continue; + + result.Add(property.Name, source); + } + + if (source is IPropertyProvider provider) + { + foreach (var property in provider.Properties) + { + if (result.ContainsKey(property.Key)) + continue; + + result.Add(property.Key, property.Value); + } + } + } + + return result; + } + + private void MergeProperty(object destination, Dictionary sourceProperties, PropertyInfo property) + { + if (property.PropertyType.IsTypePrimitive()) + { + if (!property.CanWrite) + return; + + if (!sourceProperties.TryGetValue(property.Name, out object? source)) + return; + + property.SetValue(destination, source.GetType().GetProperty(property.Name).GetValue(source)); + } + else if (IsArray(property)) + MergeEnumerable(destination, sourceProperties, property); + else + MergeObject(destination, sourceProperties, property); + } + + private void MergeEnumerable(object destination, Dictionary sourceProperties, PropertyInfo property) + { + if (sourceProperties.TryGetValue(property.Name, out object? sourceProperty)) + return; + + if (property.GetValue(sourceProperty) is not IEnumerable sourceValue) + return; + + var sourceElements = GetElements(sourceValue); + var destinationValue = property.GetValue(destination); + + if (destinationValue is null && !property.CanWrite) + return; + + var addMethod = property.PropertyType.GetMethod(nameof(IList.Add), BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic); + var instance = addMethod is not null ? Activator.CreateInstance(property.PropertyType) : Array.CreateInstance(property.PropertyType, sourceElements.Count); + var elementType = instance.GetType().GetElementType(); + + for (var i = 0; i < sourceElements.Count; i++) + { + /* + * TODO: handle Dictionary + */ + var sourceElement = sourceElements[i]; + var item = Activator.CreateInstance(elementType); + + Merge(sourceElement, item); + + if (addMethod is not null) + addMethod.Invoke(instance, new object[] { item }); + else + ((Array)instance).SetValue(item, i); + } + + property.SetValue(destination, instance); + } + + private void MergeObject(object destination, object source, PropertyInfo property) + { + throw new NotImplementedException(); + } + } +} diff --git a/Connected.Interop/Methods.cs b/Connected.Interop/Methods.cs new file mode 100644 index 0000000..7842a85 --- /dev/null +++ b/Connected.Interop/Methods.cs @@ -0,0 +1,127 @@ +using System.Reflection; +using Connected.ServiceModel; + +namespace Connected.Interop; + +public static class Methods +{ + /// + /// Resolves method based on generic arguments and parameter types. + /// + /// The type on which method is declared. + /// The name of the method. + /// The type arguments if a method is a generic method definition. + /// The argument types which method accepts. + /// + public static MethodInfo? ResolveMethod(this Type type, string name, Type[]? typeArguments, Type[]? parameterTypes) + { + var typeArgumentCount = typeArguments is not null ? typeArguments.Length : 0; + /* + * First, get all methods available on the type. + */ + foreach (var method in type.GetInheritedMethods()) + { + /* + * If a method is a generic method and type arguments weren't passed, + * skip this method because we are not interested in it. + */ + if (method.IsGenericMethodDefinition != typeArgumentCount > 0) + continue; + /* + * Also, the name of the method must match, of course. + */ + if (!string.Equals(method.Name, name)) + continue; + /* + * Now check if the type arguments match. + */ + if (method.IsGenericMethodDefinition && typeArguments is not null && typeArguments.Any()) + { + /* + * Check if the number of arguments equals on both the definition and typeArguments argument. + */ + if (method.GetGenericArguments().Length != typeArgumentCount) + continue; + /* + * Now try to create a generic method definition. + */ + var constructedMethod = method.MakeGenericMethod(typeArguments); + /* + * And if parameters also match we have a target. + */ + if (ParametersMatch(constructedMethod.GetParameters(), parameterTypes)) + return constructedMethod; + } + /* + * The method is not a generic method, we're only going to check the parameters types. + */ + if (ParametersMatch(method.GetParameters(), parameterTypes)) + return method; + } + + return default; + } + + internal static bool ParametersMatch(ParameterInfo[] parameters, Type[] parameterTypes) + { + /* + * Parameterless service methods must pass IDto argument into api methods. First check this. + */ + if (parameterTypes is not null && parameterTypes.Length == 1 && parameterTypes[0] == typeof(Dto) && (parameters is null || parameters.Length == 0)) + return true; + + if (parameters.Length != parameterTypes.Length) + return false; + + for (var i = 0; i < parameters.Length; i++) + { + if (!parameters[i].ParameterType.IsAssignableFrom(parameterTypes[i])) + return false; + } + + return true; + } + + public static IEnumerable GetInheritedMethods(this Type type) + { + foreach (var info in type.GetInheritedTypeInfos()) + { + foreach (var p in info.GetRuntimeMethods()) + yield return p; + } + } + + public static async Task InvokeAsync(this MethodInfo method, object component, params object[] parameters) + { + if (method.ReturnType is null) + { + method.Invoke(component, parameters); + + return Task.FromResult(null); + } + else + { + var isAwaitable = method.ReturnType.GetMethod(nameof(Task.GetAwaiter)) is not null; + + if (isAwaitable) + { + if (method.ReturnType.IsGenericType) + return await (dynamic)method.Invoke(component, parameters); + else + { + await (Task)method.Invoke(component, parameters); + return null; + } + } + else + { + if (method.ReturnType == typeof(void)) + method.Invoke(component, parameters); + else + return Task.FromResult(method.Invoke(component, parameters)); + } + } + + return Task.FromResult(null); + } +} diff --git a/Connected.Interop/Nullables.cs b/Connected.Interop/Nullables.cs new file mode 100644 index 0000000..3903c9f --- /dev/null +++ b/Connected.Interop/Nullables.cs @@ -0,0 +1,39 @@ +using System.Linq.Expressions; +using System.Reflection; + +namespace Connected.Interop +{ + public static class Nullables + { + public static bool IsNullableType(this Type type) + { + return type is not null && type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); + } + + public static bool IsNullAssignable(this Type type) + { + return !type.GetTypeInfo().IsValueType || type.IsNullableType(); + } + + public static Type GetNonNullableType(this Type type) + { + if (type.IsNullableType()) + return type.GetTypeInfo().GenericTypeArguments[0]; + + return type; + } + + public static Type GetNullAssignableType(this Type type) + { + if (!type.IsNullAssignable()) + return typeof(Nullable<>).MakeGenericType(type); + + return type; + } + + public static ConstantExpression GetNullConstant(this Type type) + { + return Expression.Constant(null, type.GetNullAssignableType()); + } + } +} diff --git a/Connected.Interop/Properties.cs b/Connected.Interop/Properties.cs new file mode 100644 index 0000000..013c380 --- /dev/null +++ b/Connected.Interop/Properties.cs @@ -0,0 +1,111 @@ +using System.Reflection; +using System.Runtime.CompilerServices; + +namespace Connected.Interop +{ + public static class Properties + { + public static PropertyInfo? GetPropertyAttribute(object instance) where T : Attribute + { + var props = GetProperties(instance, false); + + if (props is null || !props.Any()) + return default; + + foreach (var property in props) + { + if (property.GetCustomAttribute() is not null) + return property; + } + + return default; + } + + public static PropertyInfo[]? GetProperties(object instance, bool writableOnly) + { + if (instance.GetType().GetProperties() is not PropertyInfo[] properties) + return default; + + var temp = new List(); + + foreach (var i in properties) + { + var getMethod = i.GetGetMethod(); + var setMethod = i.GetSetMethod(); + + if (writableOnly && setMethod is null) + continue; + + if (getMethod is null) + continue; + + if (getMethod is not null && getMethod.IsStatic || setMethod is not null && setMethod.IsStatic) + continue; + + if (setMethod is not null && !setMethod.IsPublic) + continue; + + temp.Add(i); + } + + return temp.ToArray(); + } + + public static IEnumerable GetInheritedProperites(this Type type) + { + foreach (var info in type.GetInheritedTypeInfos()) + { + foreach (var p in info.GetRuntimeProperties()) + yield return p; + } + } + + public static List GetImplementedProperties(object component) + { + var type = component is Type ct ? ct : component.GetType(); + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.GetProperty | BindingFlags.SetProperty); + var result = new List(); + + foreach (var property in properties) + { + if (property.GetCustomAttribute() is not null) + continue; + + result.Add(property); + } + + return result; + } + + public static void SetPropertyValue(object instance, string propertyName, object? value) + { + var property = instance.GetType().GetProperty(propertyName); + + if (property is null) + return; + + if (!property.CanWrite) + { + if (property.DeclaringType is null) + return; + + property = property.DeclaringType.GetProperty(propertyName); + } + + if (property is null || property.SetMethod is null) + return; + + property.SetMethod.Invoke(instance, new object[] { value }); + } + + public static T? FindAttribute(this PropertyInfo info) where T : Attribute + { + var atts = info.GetCustomAttributes(true); + + if (atts is null || !atts.Any()) + return default; + + return atts.ElementAt(0); + } + } +} diff --git a/Connected.Interop/Reflection/IStringConcatenator.cs b/Connected.Interop/Reflection/IStringConcatenator.cs new file mode 100644 index 0000000..f4a2e1b --- /dev/null +++ b/Connected.Interop/Reflection/IStringConcatenator.cs @@ -0,0 +1,7 @@ +namespace Connected.Interop.Reflection +{ + internal interface IStringConcatenator + { + string Concatenate(string[] values); + } +} diff --git a/Connected.Interop/Reflection/IStringSplitter.cs b/Connected.Interop/Reflection/IStringSplitter.cs new file mode 100644 index 0000000..05d8c29 --- /dev/null +++ b/Connected.Interop/Reflection/IStringSplitter.cs @@ -0,0 +1,7 @@ +namespace Connected.Interop.Reflection +{ + internal interface IStringSplitter + { + string[] Split(string valueList); + } +} diff --git a/Connected.Interop/Reflection/InvalidConversionException.cs b/Connected.Interop/Reflection/InvalidConversionException.cs new file mode 100644 index 0000000..2ba0836 --- /dev/null +++ b/Connected.Interop/Reflection/InvalidConversionException.cs @@ -0,0 +1,10 @@ +namespace Connected.Interop.Reflection +{ + internal class InvalidConversionException : InvalidOperationException + { + public InvalidConversionException(object valueToConvert, Type destinationType) + : base($"'{valueToConvert}' ({valueToConvert?.GetType()}) is not convertible to '{destinationType}'.") + { + } + } +} diff --git a/Connected.Interop/Reflection/ObjectComparer.cs b/Connected.Interop/Reflection/ObjectComparer.cs new file mode 100644 index 0000000..cdc954f --- /dev/null +++ b/Connected.Interop/Reflection/ObjectComparer.cs @@ -0,0 +1,69 @@ +using System.Collections; +using System.Reflection; + +namespace Connected.Interop.Reflection +{ + public static class ObjectComparer + { + public static bool Compare(object left, object right) + { + if (left is null || right is null) + return false; + + if (left.GetType() != right.GetType()) + return false; + + var leftProperties = left.GetType().GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); ; + + foreach (var property in leftProperties) + { + if (!Compare(left, right, property)) + return false; + } + + return true; + } + private static bool Compare(object left, object right, PropertyInfo property) + { + if (property.PropertyType.IsEnumerable()) + { + if (!Compare(left as IEnumerable, right as IEnumerable)) + return false; + } + else if (property.PropertyType.IsTypePrimitive()) + { + if (Comparer.DefaultInvariant.Compare(left, right) != 0) + return false; + } + else + { + var leftValue = property.GetValue(left); + var rightValue = property.GetValue(right); + + if (!Compare(leftValue, rightValue)) + return false; + } + + return true; + } + private static bool Compare(IEnumerable left, IEnumerable right) + { + var leftEn = left.GetEnumerator(); + var rightEn = right.GetEnumerator(); + + while (leftEn.MoveNext()) + { + if (!rightEn.MoveNext()) + return false; + + if (!Compare(leftEn.Current, rightEn.Current)) + return false; + } + + if (rightEn.MoveNext()) + return false; + + return true; + } + } +} diff --git a/Connected.Interop/Reflection/StringConcatenator.cs b/Connected.Interop/Reflection/StringConcatenator.cs new file mode 100644 index 0000000..51be2a4 --- /dev/null +++ b/Connected.Interop/Reflection/StringConcatenator.cs @@ -0,0 +1,74 @@ +using System.Text; + +namespace Connected.Interop.Reflection +{ + internal class StringConcatenator : IStringConcatenator + { + private readonly string _separator; + private readonly string _nullValue; + private readonly ConcatenationOptions _concatenationOptions; + + public StringConcatenator() + : this(TypeConverter.DefaultStringSeparator) + { + } + + public StringConcatenator(string separator) + : this(separator, ConcatenationOptions.Default) + { + } + + public StringConcatenator(string separator, ConcatenationOptions concatenationOptions) + : this(separator, TypeConverter.DefaultNullStringValue, concatenationOptions) + { + } + + public StringConcatenator(string separator, string nullValue) + : this(separator, nullValue, ConcatenationOptions.Default) + { + } + + public StringConcatenator(string separator, string nullValue, ConcatenationOptions concatenationOptions) + { + if (separator == null) + throw new ArgumentNullException("separator"); + + _separator = separator; + _nullValue = nullValue; + _concatenationOptions = concatenationOptions; + } + + public string Concatenate(string[] values) + { + var valuesToConcatenate = values.Where(v => !IgnoreValue(v)).Select(value => value ?? _nullValue).ToArray(); + + return ConcatenateCore(valuesToConcatenate); + } + + private bool IgnoreValue(string value) + { + if (value is null && (_concatenationOptions & ConcatenationOptions.IgnoreNull) == ConcatenationOptions.IgnoreNull) + return true; + + if (string.IsNullOrEmpty(value) && (_concatenationOptions & ConcatenationOptions.IgnoreEmpty) == ConcatenationOptions.IgnoreEmpty) + return true; + + return false; + } + + protected virtual string ConcatenateCore(string[] values) + { + var result = new StringBuilder(); + + foreach (string value in values) + { + if (result.Length > 0) + result.Append(_separator); + + result.Append(value); + } + + return result.ToString(); + } + } +} \ No newline at end of file diff --git a/Connected.Interop/Reflection/StringSplitter.cs b/Connected.Interop/Reflection/StringSplitter.cs new file mode 100644 index 0000000..e73e569 --- /dev/null +++ b/Connected.Interop/Reflection/StringSplitter.cs @@ -0,0 +1,22 @@ +namespace Connected.Interop.Reflection +{ + internal class StringSplitter : IStringSplitter + { + private readonly string _separator; + + public StringSplitter() + : this(TypeConverter.DefaultStringSeparator) + { + } + + public StringSplitter(string seperator) + { + _separator = seperator ?? throw new ArgumentNullException("separator"); + } + + public string[] Split(string valueList) + { + return valueList.Split(new[] { _separator }, StringSplitOptions.None); + } + } +} \ No newline at end of file diff --git a/Connected.Interop/Reflection/TypeConverter.cs b/Connected.Interop/Reflection/TypeConverter.cs new file mode 100644 index 0000000..77ddab8 --- /dev/null +++ b/Connected.Interop/Reflection/TypeConverter.cs @@ -0,0 +1,883 @@ +using System.Collections; +using System.ComponentModel; +using System.Globalization; +using System.Reflection; + +namespace Connected.Interop.Reflection +{ + [Flags] + internal enum ConcatenationOptions + { + None = 0, + IgnoreNull = 1, + IgnoreEmpty = 2, + Default = None + } + [Flags] + internal enum ConversionOptions + { + None = 0, + EnhancedTypicalValues = 1, + AllowDefaultValueIfNull = 2, + AllowDefaultValueIfWhitespace = 4, + Default = EnhancedTypicalValues | AllowDefaultValueIfNull | AllowDefaultValueIfWhitespace + } + + internal static class TypeConverter + { + + public class EnumerableStringConversion : EnumerableConversion + { + private bool _ignoreEmptyElements; + private bool _trimStart; + private bool _trimEnd; + private string[] _nullValues = new[] { DefaultNullStringValue }; + + internal EnumerableStringConversion(string valueList, IStringSplitter stringSplitter) + : base(stringSplitter.Split(valueList)) + { + } + + internal EnumerableStringConversion(string valueList, Type destinationType, IStringSplitter stringSplitter) + : base(stringSplitter.Split(valueList), destinationType) + { + } + + public EnumerableStringConversion IgnoringEmptyElements() + { + _ignoreEmptyElements = true; + + return this; + } + + public EnumerableStringConversion TrimmingStartOfElements() + { + _trimStart = true; + + return this; + } + + public EnumerableStringConversion TrimmingEndOfElements() + { + _trimEnd = true; + + return this; + } + + public EnumerableStringConversion WithNullBeing(params string[] nullValues) + { + _nullValues = nullValues; + + return this; + } + + protected override IEnumerable GetValuesToConvert() + { + var valuesToConvert = new List(); + string valueToConvert; + + foreach (string value in base.GetValuesToConvert()) + { + valueToConvert = PreProcessValueToConvert(value); + + if (ValueShouldBeIgnored(valueToConvert)) + continue; + + valuesToConvert.Add(valueToConvert); + } + + return valuesToConvert; + } + + private string PreProcessValueToConvert(string value) + { + var valueToConvert = value; + + if (_trimStart) + valueToConvert = valueToConvert.TrimStart(); + + if (_trimEnd) + valueToConvert = valueToConvert.TrimEnd(); + + return ValueOrNull(valueToConvert); + } + + private bool ValueShouldBeIgnored(string valueToConvert) + { + if (string.IsNullOrEmpty(valueToConvert) && _ignoreEmptyElements) + return true; + + return false; + } + + private string? ValueOrNull(string value) + { + if (_nullValues is null) + return value; + + var result = value; + + if (_nullValues.Contains(value)) + result = null; + + return result; + } + } + public class EnumerableConversion : IEnumerable + { + private readonly IEnumerable _valuesToConvert; + private readonly Type _destinationType = typeof(T); + private CultureInfo mCulture; + private ConversionOptions _conversionOptions = ConversionOptions.Default; + private bool _ignoreNullElements; + private bool _ignoreNonConvertibleElements; + + private CultureInfo Culture => mCulture ?? DefaultCulture; + + internal EnumerableConversion(IEnumerable values, Type destinationType) + : this(values) + { + _destinationType = destinationType; + } + + internal EnumerableConversion(IEnumerable values) + { + _valuesToConvert = values; + } + + public EnumerableConversion UsingCulture(CultureInfo culture) + { + mCulture = culture; + + return this; + } + + public EnumerableConversion UsingConversionOptions(ConversionOptions options) + { + _conversionOptions = options; + + return this; + } + + public EnumerableConversion IgnoringNonConvertibleElements() + { + _ignoreNonConvertibleElements = true; + + return this; + } + + public EnumerableConversion IgnoringNullElements() + { + _ignoreNullElements = true; + + return this; + } + + public bool Try(out IEnumerable result) + { + return TryConvert(out result); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + if (TryConvert(out IEnumerable result, out Exception exception)) + return result.GetEnumerator(); + + throw exception; + } + + private bool TryConvert(out IEnumerable result) + { + return TryConvert(out result, out Exception _); + } + + private bool TryConvert(out IEnumerable? result, out Exception? exception) + { + var convertedValues = new List(); + + foreach (object value in GetValuesToConvert()) + { + if (value is null && _ignoreNullElements) + continue; + + if (!TypeConverter.TryConvert(value, _destinationType, out object convertedValue, Culture, _conversionOptions)) + { + if (_ignoreNonConvertibleElements) + continue; + + result = null; + exception = new InvalidConversionException(value, _destinationType); + + return false; + } + + convertedValues.Add((T)convertedValue); + } + + result = convertedValues; + exception = null; + + return true; + } + + protected virtual IEnumerable GetValuesToConvert() + { + return _valuesToConvert; + } + } + + private const string ImplicitOperatorMethodName = "op_Implicit"; + private const string ExplicitOperatorMethodName = "op_Explicit"; + public static readonly CultureInfo DefaultCulture = CultureInfo.InvariantCulture; + public const string DefaultNullStringValue = ".null."; + public const string DefaultStringSeparator = ";"; + + public static bool TryConvertInvariant(object? value, out T? r) + { + try + { + r = ConvertTo(value, CultureInfo.InvariantCulture, ConversionOptions.AllowDefaultValueIfNull | ConversionOptions.AllowDefaultValueIfWhitespace | ConversionOptions.EnhancedTypicalValues); + + return true; + } + catch + { + + r = default; + + return false; + } + } + + public static bool CanConvertTo(object? value) + { + return TryConvertTo(value, out T _); + } + + public static bool CanConvertTo(object? value, CultureInfo culture) + { + return TryConvertTo(value, out T _, culture); + } + + public static bool CanConvertTo(object? value, ConversionOptions options) + { + return TryConvertTo(value, out T _, options); + } + + public static bool CanConvertTo(object? value, CultureInfo culture, ConversionOptions options) + { + return TryConvertTo(value, out T _, culture, options); + } + + public static bool TryConvertTo(object? value, out T result) + { + return TryConvertTo(value, out result, DefaultCulture); + } + + public static bool TryConvertTo(object? value, out T result, CultureInfo culture) + { + return TryConvertTo(value, out result, culture, ConversionOptions.Default); + } + + public static bool TryConvertTo(object? value, out T result, ConversionOptions options) + { + return TryConvertTo(value, out result, DefaultCulture, options); + } + + public static bool TryConvertTo(object? value, out T? result, CultureInfo culture, ConversionOptions options) + { + if (TryConvert(value, typeof(T), out object tmpResult, culture, options)) + { + result = (T)tmpResult; + return true; + } + result = default; + return false; + } + + public static T ConvertTo(object? value) + { + return ConvertTo(value, DefaultCulture); + } + + public static T ConvertTo(object? value, CultureInfo culture) + { + return ConvertTo(value, culture, ConversionOptions.Default); + } + + public static T ConvertTo(object? value, ConversionOptions options) + { + return ConvertTo(value, DefaultCulture, options); + } + + public static T ConvertTo(object? value, CultureInfo culture, ConversionOptions options) + { + return (T)Convert(value, typeof(T), culture, options); + } + + public static bool CanConvert(object? value, Type destinationType) + { + return TryConvert(value, destinationType, out object _); + } + + public static bool CanConvert(object? value, Type destinationType, CultureInfo culture) + { + return TryConvert(value, destinationType, out object _, culture); + } + + public static bool CanConvert(object? value, Type destinationType, ConversionOptions options) + { + return TryConvert(value, destinationType, out object _, options); + } + + public static bool CanConvert(object? value, Type destinationType, CultureInfo culture, ConversionOptions options) + { + return TryConvert(value, destinationType, out object _, culture, options); + } + + public static bool TryConvert(object? value, Type destinationType, out object? result) + { + return TryConvert(value, destinationType, out result, DefaultCulture); + } + + public static bool TryConvert(object? value, Type destinationType, out object? result, CultureInfo culture) + { + return TryConvert(value, destinationType, out result, culture, ConversionOptions.Default); + } + + public static bool TryConvert(object? value, Type destinationType, out object? result, ConversionOptions options) + { + return TryConvert(value, destinationType, out result, DefaultCulture, options); + } + + public static bool TryConvert(object? value, Type destinationType, out object? result, CultureInfo culture, ConversionOptions options) + { + if (destinationType == typeof(object)) + { + result = value; + return true; + } + + if (ValueRepresentsNull(value)) + return TryConvertFromNull(destinationType, out result, options); + + if (destinationType.IsAssignableFrom(value.GetType())) + { + result = value; + return true; + } + + Type coreDestinationType = IsGenericNullable(destinationType) ? GetUnderlyingType(destinationType) : destinationType; + object? tmpResult = null; + + if (TryConvertCore(value, coreDestinationType, ref tmpResult, culture, options)) + { + result = tmpResult; + return true; + } + + result = null; + + return false; + } + + public static object? Convert(object? value, Type destinationType) + { + return Convert(value, destinationType, DefaultCulture); + } + + public static object? Convert(object? value, Type destinationType, CultureInfo culture) + { + return Convert(value, destinationType, culture, ConversionOptions.Default); + } + + public static object? Convert(object? value, Type destinationType, ConversionOptions options) + { + return Convert(value, destinationType, DefaultCulture, options); + } + + public static object? Convert(object? value, Type destinationType, CultureInfo culture, ConversionOptions options) + { + return TryConvert(value, destinationType, out var result, culture, options) + ? result + : throw new InvalidConversionException(value, destinationType); + } + + private static bool TryConvertSpecialValues(object? value, Type destinationType, ref object result) + { + if (value is char && destinationType == typeof(bool)) + return TryConvertCharToBool((char)value, ref result); + + if (value is string && destinationType == typeof(bool)) + return TryConvertStringToBool((string)value, ref result); + + if (value is bool && destinationType == typeof(char)) + return ConvertBoolToChar((bool)value, out result); + + return false; + } + + private static bool TryConvertCharToBool(char value, ref object result) + { + if ("1JYT".Contains(value.ToString().ToUpper())) + { + result = true; + return true; + } + + if ("0NF".Contains(value.ToString().ToUpper())) + { + result = false; + return true; + } + return false; + } + + private static bool TryConvertStringToBool(string value, ref object result) + { + var trueValues = new List(new[] { "1", "j", "ja", "y", "yes", "true", "t", ".t." }); + + if (trueValues.Contains(value.Trim().ToLower())) + { + result = true; + return true; + } + + var falseValues = new List(new[] { "0", "n", "no", "nein", "false", "f", ".f." }); + + if (falseValues.Contains(value.Trim().ToLower())) + { + result = false; + return true; + } + + return false; + } + + private static bool ConvertBoolToChar(bool value, out object result) + { + result = value ? 'T' : 'F'; + + return true; + } + + private static bool TryConvertFromNull(Type destinationType, out object result, ConversionOptions options) + { + result = GetDefaultValueOfType(destinationType); + + if (result is null) + return true; + + return (options & ConversionOptions.AllowDefaultValueIfNull) == ConversionOptions.AllowDefaultValueIfNull; + } + + private static bool TryConvertCore(object value, Type destinationType, ref object result, CultureInfo culture, ConversionOptions options) + { + if (value.GetType() == destinationType) + { + result = value; + return true; + } + + if (TryConvertByDefaultTypeConverters(value, destinationType, culture, ref result)) + return true; + + if (TryConvertByIConvertibleImplementation(value, destinationType, culture, ref result)) + return true; + + if (TryConvertXPlicit(value, destinationType, ExplicitOperatorMethodName, ref result)) + return true; + + if (TryConvertXPlicit(value, destinationType, ImplicitOperatorMethodName, ref result)) + return true; + + if (TryConvertByIntermediateConversion(value, destinationType, ref result, culture, options)) + return true; + + if (destinationType.IsEnum) + { + if (TryConvertToEnum(value, destinationType, ref result)) + return true; + else if (value == null || string.IsNullOrWhiteSpace(value as string)) + { + if (destinationType.IsEnumDefined(0)) + { + result = Enum.Parse(destinationType, destinationType.GetEnumName(0)); + return true; + } + } + } + else if (destinationType == typeof(TimeSpan)) + { + if (value is DateTime d) + { + result = new TimeSpan(0, d.Hour, d.Minute, d.Second, d.Millisecond); + + return true; + } + else if (value is string) + { + if (DateTime.TryParse(value.ToString(), out DateTime dt)) + { + result = new TimeSpan(0, dt.Hour, dt.Minute, dt.Second, dt.Millisecond); + + return true; + + } + } + } + + if ((options & ConversionOptions.EnhancedTypicalValues) == ConversionOptions.EnhancedTypicalValues) + { + if (TryConvertSpecialValues(value, destinationType, ref result)) + return true; + } + + if ((options & ConversionOptions.AllowDefaultValueIfWhitespace) == ConversionOptions.AllowDefaultValueIfWhitespace) + { + if (value is string) + { + if (IsWhiteSpace((string)value)) + { + result = GetDefaultValueOfType(destinationType); + return true; + } + } + } + + return false; + } + + private static bool TryConvertByDefaultTypeConverters(object? value, Type destinationType, CultureInfo culture, ref object result) + { + System.ComponentModel.TypeConverter converter = TypeDescriptor.GetConverter(destinationType); + + if (converter != null) + { + if (converter.CanConvertFrom(value.GetType())) + { + try + { + result = converter.ConvertFrom(null, culture, value); + return true; + } + catch { } + } + } + + converter = TypeDescriptor.GetConverter(value); + + if (converter != null) + { + if (converter.CanConvertTo(destinationType)) + { + try + { + result = converter.ConvertTo(null, culture, value, destinationType); + return true; + } + catch { } + } + } + + return false; + } + + private static bool TryConvertByIConvertibleImplementation(object value, Type destinationType, IFormatProvider formatProvider, ref object result) + { + if (value is IConvertible convertible) + { + try + { + if (destinationType == typeof(bool)) + { + result = convertible.ToBoolean(formatProvider); + return true; + } + + if (destinationType == typeof(byte)) + { + result = convertible.ToByte(formatProvider); + return true; + } + + if (destinationType == typeof(char)) + { + result = convertible.ToChar(formatProvider); + return true; + } + + if (destinationType == typeof(DateTime)) + { + result = convertible.ToDateTime(formatProvider); + return true; + } + + if (destinationType == typeof(decimal)) + { + result = convertible.ToDecimal(formatProvider); + return true; + } + + if (destinationType == typeof(double)) + { + result = convertible.ToDouble(formatProvider); + return true; + } + + if (destinationType == typeof(short)) + { + result = convertible.ToInt16(formatProvider); + return true; + } + + if (destinationType == typeof(int)) + { + result = convertible.ToInt32(formatProvider); + return true; + } + + if (destinationType == typeof(long)) + { + result = convertible.ToInt64(formatProvider); + return true; + } + + if (destinationType == typeof(sbyte)) + { + result = convertible.ToSByte(formatProvider); + return true; + } + + if (destinationType == typeof(float)) + { + result = convertible.ToSingle(formatProvider); + return true; + } + + if (destinationType == typeof(ushort)) + { + result = convertible.ToUInt16(formatProvider); + return true; + } + + if (destinationType == typeof(uint)) + { + result = convertible.ToUInt32(formatProvider); + return true; + } + + if (destinationType == typeof(ulong)) + { + result = convertible.ToUInt64(formatProvider); + return true; + } + } + catch + { + return false; + } + } + + return false; + } + + private static bool TryConvertXPlicit(object value, Type destinationType, string operatorMethodName, ref object result) + { + if (TryConvertXPlicit(value, value.GetType(), destinationType, operatorMethodName, ref result)) + return true; + + if (TryConvertXPlicit(value, destinationType, destinationType, operatorMethodName, ref result)) + return true; + + return false; + } + + private static bool TryConvertXPlicit(object value, Type invokerType, Type destinationType, string xPlicitMethodName, ref object? result) + { + var methods = invokerType.GetMethods(BindingFlags.Public | BindingFlags.Static); + + foreach (MethodInfo method in methods.Where(m => m.Name == xPlicitMethodName)) + { + if (destinationType.IsAssignableFrom(method.ReturnType)) + { + var parameters = method.GetParameters(); + + if (parameters.Count() == 1 && parameters[0].ParameterType == value.GetType()) + { + try + { + result = method.Invoke(null, new[] { value }); + return true; + } + catch { } + } + } + } + + return false; + } + + private static bool TryConvertByIntermediateConversion(object value, Type destinationType, ref object result, CultureInfo culture, ConversionOptions options) + { + if (value is char && (destinationType == typeof(double) || destinationType == typeof(float))) + return TryConvertCore(System.Convert.ToInt16(value), destinationType, ref result, culture, options); + + if ((value is double || value is float) && destinationType == typeof(char)) + return TryConvertCore(System.Convert.ToInt16(value), destinationType, ref result, culture, options); + + return false; + } + + private static bool TryConvertToEnum(object value, Type destinationType, ref object result) + { + try + { + result = Enum.ToObject(destinationType, value); + return true; + } + catch + { + return false; + } + } + + public static EnumerableConversion ConvertToEnumerable(IEnumerable values) + { + return new EnumerableConversion(values); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList) + { + return ConvertToEnumerable(valueList, new StringSplitter()); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList, string seperator) + { + return ConvertToEnumerable(valueList, new StringSplitter(seperator)); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList, IStringSplitter stringSplitter) + { + return new EnumerableStringConversion(valueList, stringSplitter); + } + + public static EnumerableConversion ConvertToEnumerable(IEnumerable values, Type destinationType) + { + return new EnumerableConversion(values, destinationType); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList, Type destinationType) + { + return ConvertToEnumerable(valueList, destinationType, DefaultStringSeparator); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList, Type destinationType, string seperator) + { + return ConvertToEnumerable(valueList, destinationType, new StringSplitter(seperator)); + } + + public static EnumerableStringConversion ConvertToEnumerable(string valueList, Type destinationType, IStringSplitter stringSplitter) + { + return new EnumerableStringConversion(valueList, destinationType, stringSplitter); + } + + public static string ConvertToStringRepresentation(IEnumerable values) + { + return ConvertToStringRepresentation(values, DefaultCulture, new StringConcatenator()); + } + + public static string ConvertToStringRepresentation(IEnumerable values, string seperator) + { + return ConvertToStringRepresentation(values, DefaultCulture, new StringConcatenator(seperator)); + } + + public static string ConvertToStringRepresentation(IEnumerable values, string seperator, string nullValue) + { + return ConvertToStringRepresentation(values, DefaultCulture, new StringConcatenator(seperator, nullValue)); + } + + public static string ConvertToStringRepresentation(IEnumerable values, CultureInfo culture) + { + return ConvertToStringRepresentation(values, culture, new StringConcatenator()); + } + + public static string ConvertToStringRepresentation(IEnumerable values, IStringConcatenator stringConcatenator) + { + return ConvertToStringRepresentation(values, DefaultCulture, stringConcatenator); + } + + public static string ConvertToStringRepresentation(IEnumerable values, CultureInfo culture, IStringConcatenator stringConcatenator) + { + string[] stringValues = ConvertToEnumerable(values).UsingCulture(culture).ToArray(); + + return stringConcatenator.Concatenate(stringValues); + } + + private static bool ValueRepresentsNull(object value) + { + return value == null || value == DBNull.Value; + } + + private static object? GetDefaultValueOfType(Type type) + { + return type.IsValueType ? Activator.CreateInstance(type) : null; + } + + private static bool IsGenericNullable(Type type) + { + return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>).GetGenericTypeDefinition(); + } + + private static Type? GetUnderlyingType(Type type) + { + return Nullable.GetUnderlyingType(type); + } + + private static bool IsWhiteSpace(string value) + { + for (var i = 0; i < value.Length; i++) + { + if (!char.IsWhiteSpace(value[i])) + return false; + } + + return true; + } + + public static bool IsDefaultValue(object value) + { + if (value is null || value == DBNull.Value) + return true; + + if (value is int i) + return i == 0; + else if (value is byte) + return (byte)value == 0; + else if (value is short) + return (short)value == 0; + else if (value is float) + return (float)value == 0; + else if (value is double) + return (double)value == 0; + else if (value is decimal) + return (decimal)value == 0; + else if (value is long) + return (long)value == 0; + else if (value is string) + return string.IsNullOrWhiteSpace(value as string); + else if (value is DateTime) + return (DateTime)value == DateTime.MinValue; + else if (value is Guid) + return (Guid)value == Guid.Empty; + else + return false; + } + + } +} \ No newline at end of file diff --git a/Connected.Interop/SR.Designer.cs b/Connected.Interop/SR.Designer.cs new file mode 100644 index 0000000..5d78d99 --- /dev/null +++ b/Connected.Interop/SR.Designer.cs @@ -0,0 +1,81 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Interop { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Interop.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to The specified value type contains generic parameters. Default value cannot be retrieved. + /// + internal static string ErrDefaultGeneric { + get { + return ResourceManager.GetString("ErrDefaultGeneric", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Invalid interface specified. + /// + internal static string ErrInvalidInterface { + get { + return ResourceManager.GetString("ErrInvalidInterface", resourceCulture); + } + } + } +} diff --git a/Connected.Interop/SR.resx b/Connected.Interop/SR.resx new file mode 100644 index 0000000..690d617 --- /dev/null +++ b/Connected.Interop/SR.resx @@ -0,0 +1,126 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + The specified value type contains generic parameters. Default value cannot be retrieved + + + Invalid interface specified + + \ No newline at end of file diff --git a/Connected.Interop/Serializer.cs b/Connected.Interop/Serializer.cs new file mode 100644 index 0000000..e4732b3 --- /dev/null +++ b/Connected.Interop/Serializer.cs @@ -0,0 +1,205 @@ +using Connected.Interop.Merging; +using System.Collections; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace Connected.Interop; + +public static class Serializer +{ + private static readonly JsonSerializerOptions _options; + + static Serializer() + { + _options = new JsonSerializerOptions + { + AllowTrailingCommas = true, + IncludeFields = false, + IgnoreReadOnlyFields = false, + IgnoreReadOnlyProperties = true, + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + WriteIndented = true + }; + } + + internal static JsonSerializerOptions SerializerOptions => _options; + + public static async Task Deserialize(string value) + { + using var ms = new MemoryStream(Encoding.UTF8.GetBytes(value)); + + ms.Seek(0, SeekOrigin.Begin); + + return await JsonSerializer.DeserializeAsync(ms, _options); + } + + public static async Task Serialize(object value) + { + using var ms = new MemoryStream(); + using var writer = new Utf8JsonWriter(ms, new JsonWriterOptions { Indented = true, SkipValidation = false }); + + if (value.GetType().IsEnumerable()) + SerializeArray(writer, null, value); + else if (value.GetType().IsTypePrimitive()) + SerializePrimitive(writer, value); + else + SerializeObject(writer, null, value); + + await writer.FlushAsync(); + + ms.Seek(0, SeekOrigin.Begin); + + return Encoding.UTF8.GetString(ms.ToArray()); + } + + private static void SerializeObject(Utf8JsonWriter writer, PropertyInfo property, object value) + { + if (value is null) + return; + + var properties = value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance); + + if (property is not null) + writer.WriteStartObject(property.Name.ToCamelCase()); + else + writer.WriteStartObject(); + + foreach (var p in properties) + SerializeProperty(writer, p, value); + + writer.WriteEndObject(); + } + + private static void SerializeArray(Utf8JsonWriter writer, PropertyInfo property, object value) + { + if (value is null) + return; + + if (property is not null) + writer.WriteStartArray(property.Name.ToCamelCase()); + else + writer.WriteStartArray(); + + if (value is IEnumerable enumerable) + { + var enumerator = enumerable.GetEnumerator(); + + while (enumerator.MoveNext()) + SerializeObject(writer, null, enumerator.Current); + } + + writer.WriteEndArray(); + } + + private static void SerializeProperty(Utf8JsonWriter writer, PropertyInfo property, object value) + { + if (property.PropertyType.IsEnumerable()) + SerializeArray(writer, property, value); + else if (!property.PropertyType.IsTypePrimitive()) + SerializeObject(writer, property, value); + + writer.WritePropertyName(property.Name.ToCamelCase()); + + SerializePrimitive(writer, property.GetValue(value)); + } + + private static void SerializePrimitive(Utf8JsonWriter writer, object? value) + { + if (value is null) + { + writer.WriteNullValue(); + return; + } + + if (value.GetType().IsNumber()) + SerializeNumber(writer, value); + else if (value is string) + writer.WriteStringValue(value.ToString()); + else if (value is DateTime date) + writer.WriteStringValue(date); + else if (value is DateTimeOffset offset) + writer.WriteStringValue(offset); + else if (value is Guid guid) + writer.WriteStringValue(guid); + else if (value is bool @bool) + writer.WriteBooleanValue(@bool); + else if (value is Enum en) + { + var type = Enum.GetUnderlyingType(en.GetType()); + + if (TypeConversion.TryConvert(value, out object? enumNumber, type)) + SerializeNumber(writer, enumNumber); + } + } + + private static void SerializeNumber(Utf8JsonWriter writer, object? value) + { + if (value is decimal @decimal) + writer.WriteNumberValue(@decimal); + else if (value is double @double) + writer.WriteNumberValue(@double); + else if (value is float @float) + writer.WriteNumberValue(@float); + else if (value is int @int) + writer.WriteNumberValue(@int); + else if (value is long @long) + writer.WriteNumberValue(@long); + else if (value is uint @uint) + writer.WriteNumberValue(@uint); + else if (value is ulong @ulong) + writer.WriteNumberValue(@ulong); + else if (value is byte @byte) + writer.WriteNumberValue(@byte); + else if (value is sbyte @sbyte) + writer.WriteNumberValue(@sbyte); + else if (value is short @short) + writer.WriteNumberValue(@short); + else if (value is ushort @ushort) + writer.WriteNumberValue(@ushort); + } + + public static T Merge(T destination, params object[] sources) + { + if (destination is null) + return default; + + if (sources is null || !sources.Any()) + return destination; + + var hasJson = false; + + foreach (var source in sources) + { + if (source is JsonNode) + { + hasJson = true; + break; + } + } + + if (!hasJson) + new ObjectMerger().Merge(destination, sources); + else + { + foreach (var source in sources) + { + if (source is JsonNode node) + new JsonMerger().Merge(destination, node); + else + new ObjectMerger().Merge(destination, source); + } + } + + return destination; + } + + public static T Merge(T destination, string value) + { + var node = Deserialize(value); + + return Merge(destination, node); + } +} diff --git a/Connected.Interop/TypeComparer.cs b/Connected.Interop/TypeComparer.cs new file mode 100644 index 0000000..07b464d --- /dev/null +++ b/Connected.Interop/TypeComparer.cs @@ -0,0 +1,28 @@ +namespace Connected.Interop +{ + public static class TypeComparer + { + public static bool Compare(object? left, object? right) + { + if (left is null && right is null) + return true; + + if (left is null && right is not null) + return false; + + if (left is not null && right is null) + return false; + + if (!TypeConversion.TryConvert(left, out string? leftString)) + return false; + + if (!TypeConversion.TryConvert(right, out string? rightString)) + return false; + + if (Guid.TryParse(leftString, out Guid lg) && Guid.TryParse(rightString, out Guid rg)) + return lg == rg; + + return string.Equals(leftString, rightString, StringComparison.Ordinal); + } + } +} diff --git a/Connected.Interop/TypeConversion.cs b/Connected.Interop/TypeConversion.cs new file mode 100644 index 0000000..7dddef1 --- /dev/null +++ b/Connected.Interop/TypeConversion.cs @@ -0,0 +1,47 @@ +using System.Globalization; +using Connected.Interop.Reflection; + +namespace Connected.Interop; + +public static class TypeConversion +{ + public static bool TryConvertInvariant(object? value, out T? result) + { + return TypeConverter.TryConvertInvariant(value, out result); + } + + public static bool TryConvertInvariant(object? value, out object? result, Type destinationType) + { + return TypeConverter.TryConvert(value, destinationType, out result, CultureInfo.InvariantCulture); + } + + public static bool TryConvert(object? value, out object? result, Type destinationType) + { + return TypeConverter.TryConvert(value, destinationType, out result); + } + + public static bool TryConvert(object? value, out T? result) + { + return TypeConverter.TryConvertTo(value, out result); + } + + public static bool TryConvert(object? value, out T? result, CultureInfo culture) + { + return TypeConverter.TryConvertTo(value, out result, culture); + } + + public static T? Convert(object? value) + { + return TypeConverter.ConvertTo(value); + } + + public static object? Convert(object? value, Type destinationType) + { + return TypeConverter.Convert(value, destinationType); + } + + public static T? Convert(object? value, CultureInfo culture) + { + return TypeConverter.ConvertTo(value, culture); + } +} diff --git a/Connected.Interop/TypeSystem.cs b/Connected.Interop/TypeSystem.cs new file mode 100644 index 0000000..9ba736a --- /dev/null +++ b/Connected.Interop/TypeSystem.cs @@ -0,0 +1,131 @@ +using System.Data; + +namespace Connected.Interop +{ + public static class TypeSystem + { + public static bool IsInteger(Type type) + { + var nnType = type.GetNonNullableType(); + + return GetTypeCode(nnType) switch + { + TypeCode.SByte or TypeCode.Int16 or TypeCode.Int32 or TypeCode.Int64 or TypeCode.Byte or TypeCode.UInt16 or TypeCode.UInt32 or TypeCode.UInt64 => true, + _ => false, + }; + } + + public static TypeCode GetTypeCode(Type type) + { + if (type.IsEnum) + return GetTypeCode(Enum.GetUnderlyingType(type)); + + if (type == typeof(bool)) + return TypeCode.Boolean; + else if (type == typeof(byte)) + return TypeCode.Byte; + else if (type == typeof(sbyte)) + return TypeCode.SByte; + else if (type == typeof(short)) + return TypeCode.Int16; + else if (type == typeof(ushort)) + return TypeCode.UInt16; + else if (type == typeof(int)) + return TypeCode.Int32; + else if (type == typeof(uint)) + return TypeCode.UInt32; + else if (type == typeof(long)) + return TypeCode.Int64; + else if (type == typeof(ulong)) + return TypeCode.UInt64; + else if (type == typeof(float)) + return TypeCode.Single; + else if (type == typeof(double)) + return TypeCode.Double; + else if (type == typeof(decimal)) + return TypeCode.Decimal; + else if (type == typeof(string)) + return TypeCode.String; + else if (type == typeof(char)) + return TypeCode.Char; + else if (type == typeof(DateTime)) + return TypeCode.DateTime; + else + return TypeCode.Object; + } + + public static Type ToType(this DbType type) + { + return type switch + { + DbType.AnsiString or DbType.AnsiStringFixedLength or DbType.String or DbType.StringFixedLength or DbType.Xml => typeof(string), + DbType.Binary or DbType.Object => typeof(object), + DbType.Boolean => typeof(bool), + DbType.Byte => typeof(byte), + DbType.Int16 => typeof(short), + DbType.Int32 => typeof(int), + DbType.SByte => typeof(sbyte), + DbType.UInt16 => typeof(ushort), + DbType.UInt32 => typeof(uint), + DbType.Int64 => typeof(long), + DbType.UInt64 => typeof(ulong), + DbType.Currency or DbType.Decimal => typeof(decimal), + DbType.Double => typeof(double), + DbType.Single => typeof(float), + DbType.VarNumeric => typeof(decimal), + DbType.Date or DbType.DateTime or DbType.DateTime2 or DbType.Time => typeof(DateTime), + DbType.DateTimeOffset => typeof(DateTimeOffset), + DbType.Guid => typeof(Guid), + _ => throw new NotSupportedException(), + }; + } + + public static DbType ToDbType(this Type? type) + { + if (type is null) + return DbType.Object; + + var underlyingType = type; + + if (underlyingType.IsEnum) + underlyingType = Enum.GetUnderlyingType(underlyingType); + + if (underlyingType == typeof(char) || underlyingType == typeof(string)) + return DbType.String; + else if (underlyingType == typeof(byte)) + return DbType.Byte; + else if (underlyingType == typeof(bool)) + return DbType.Boolean; + else if (underlyingType == typeof(DateTime) || underlyingType == typeof(DateTimeOffset)) + return DbType.DateTime2; + else if (underlyingType == typeof(decimal)) + return DbType.Decimal; + else if (underlyingType == typeof(double)) + return DbType.Double; + else if (underlyingType == typeof(Guid)) + return DbType.Guid; + else if (underlyingType == typeof(short)) + return DbType.Int16; + else if (underlyingType == typeof(int)) + return DbType.Int32; + else if (underlyingType == typeof(long)) + return DbType.Int64; + else if (underlyingType == typeof(sbyte)) + return DbType.SByte; + else if (underlyingType == typeof(float)) + return DbType.Single; + else if (underlyingType == typeof(TimeSpan)) + return DbType.Time; + else if (underlyingType == typeof(ushort)) + return DbType.UInt16; + else if (underlyingType == typeof(uint)) + return DbType.UInt32; + else if (underlyingType == typeof(ulong)) + return DbType.UInt64; + else if (underlyingType == typeof(byte[])) + return DbType.Binary; + else + return DbType.String; + } + } +} diff --git a/Connected.Interop/Types.cs b/Connected.Interop/Types.cs new file mode 100644 index 0000000..876fac4 --- /dev/null +++ b/Connected.Interop/Types.cs @@ -0,0 +1,225 @@ +using System.Globalization; +using System.Reflection; + +namespace Connected.Interop; + +public static class Types +{ + private static Func? _getUninitializedObject; + public static bool IsAssignableFrom(this Type type, Type otherType) + { + return type.GetTypeInfo().IsAssignableFrom(otherType.GetTypeInfo()); + } + + public static ConstructorInfo? FindConstructor(this Type type, Type[] parameterTypes) + { + foreach (var constructor in type.GetTypeInfo().DeclaredConstructors) + { + if (Methods.ParametersMatch(constructor.GetParameters(), parameterTypes)) + return constructor; + } + + return default; + } + + private static bool TypesMatch(Type[] a, Type[] b) + { + if (a.Length != b.Length) + return false; + + for (var i = 0; i < a.Length; i++) + { + if (a[i] != b[i]) + return false; + } + + return true; + } + + public static IEnumerable GetInheritedTypeInfos(this Type type) + { + var info = type.GetTypeInfo(); + + yield return info; + + if (info.IsInterface) + { + foreach (var ii in info.ImplementedInterfaces) + { + foreach (var iface in ii.GetInheritedTypeInfos()) + yield return iface; + } + } + else + { + for (var i = info.BaseType?.GetTypeInfo(); i != null; i = i.BaseType?.GetTypeInfo()) + yield return i; + } + } + + public static object GetUninitializedObject(this Type type) + { + if (_getUninitializedObject is null) + { + var a = typeof(System.Runtime.CompilerServices.RuntimeHelpers).GetTypeInfo().Assembly; + var fs = a.DefinedTypes.FirstOrDefault(t => string.Equals(t.FullName, "System.Runtime.Serialization.FormatterServices")); + var guo = fs?.DeclaredMethods.FirstOrDefault(m => m.Name == nameof(GetUninitializedObject)); + + if (guo is null) + throw new NotSupportedException($"The runtime does not support the '{nameof(GetUninitializedObject)}' API."); + + Interlocked.CompareExchange(ref _getUninitializedObject, (Func)guo.CreateDelegate(typeof(Func)), null); + } + + return type.GetUninitializedObject(); + } + + public static bool IsTypePrimitive(this Type type) + { + if (type == null) + return false; + + return type.IsPrimitive + || type == typeof(string) + || type == typeof(decimal) + || type.IsEnum + || type.IsValueType; + } + + public static bool IsNumber(this Type type) + { + return type == typeof(byte) + || type == typeof(sbyte) + || type == typeof(short) + || type == typeof(ushort) + || type == typeof(int) + || type == typeof(uint) + || type == typeof(long) + || type == typeof(ulong) + || type == typeof(float) + || type == typeof(double) + || type == typeof(decimal); + } + + public static string ShortName(this Type type) + { + var r = type.Name; + + if (r.Contains('.')) + r = r.Substring(r.LastIndexOf('.') + 1); + + return r; + } + + public static object? CreateInstance(this Type type) + { + return CreateInstance(type, null); + } + + public static object? CreateInstance(this Type type, object[]? ctorArgs) + { + if (type is null) + return default; + + object? instance; + + if (ctorArgs is null) + instance = CreateInstanceInternal(type); + else + instance = CreateInstanceInternal(type, BindingFlags.CreateInstance, ctorArgs); + + return instance; + } + + public static T? CreateInstance(this Type type) + { + return CreateInstance(type, null); + } + + public static T? CreateInstance(this Type type, object[]? ctorArgs) + { + if (type is null) + return default; + + object? instance = null; + + if (ctorArgs is null) + instance = CreateInstanceInternal(type); + else + instance = CreateInstanceInternal(type, BindingFlags.CreateInstance, ctorArgs); + + if (instance is null) + throw new SysException(nameof(Types), $"{SR.ErrInvalidInterface} ({typeof(T).Name})"); + + if (TypeConversion.TryConvert(instance, out T? result)) + return result; + + return default; + } + + private static object? CreateInstanceInternal(this Type type) + { + if (type.IsTypePrimitive()) + return type.GetDefault(); + + return Activator.CreateInstance(type); + } + + private static object? CreateInstanceInternal(this Type type, BindingFlags bindingFlags, object[] ctorArgs) + { + if (type.IsTypePrimitive()) + return type.GetDefault(); + + return Activator.CreateInstance(type, bindingFlags, null, ctorArgs, CultureInfo.InvariantCulture); + } + + public static object? GetDefault(this Type type) + { + if (type is null || type == typeof(void)) + return default; + + var isNullable = !type.GetTypeInfo().IsValueType || type.IsNullableType(); + + if (isNullable) + return default; + + if (type.ContainsGenericParameters) + throw new SysException(nameof(Types), $"{MethodBase.GetCurrentMethod()} {SR.ErrDefaultGeneric} ({type})"); + + return Activator.CreateInstance(type); + } + + public static Type ResolveInterface(object component, string method, params Type[] parameters) + { + var itfs = component.GetType().GetInterfaces(); + + foreach (var i in itfs) + { + var methods = i.GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(f => string.Equals(f.Name, method, StringComparison.Ordinal)); + + foreach (var m in methods) + { + var pars = m.GetParameters(); + + if (pars.Length != parameters.Length) + continue; + + var parametersMatch = true; + + for (var j = 0; j < parameters.Length; j++) + { + if (pars[j].ParameterType != parameters[j] && !parameters[j].IsAssignableFrom(pars[j].ParameterType)) + { + parametersMatch = false; + break; + } + } + + if (parametersMatch) + return i; + } + } + + return null; + } +} diff --git a/Connected.Middleware/Annotations/MiddlewareAttribute.cs b/Connected.Middleware/Annotations/MiddlewareAttribute.cs new file mode 100644 index 0000000..f8c6e88 --- /dev/null +++ b/Connected.Middleware/Annotations/MiddlewareAttribute.cs @@ -0,0 +1,11 @@ +namespace Connected.Middleware.Annotations; + +[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] +public sealed class MiddlewareAttribute : Attribute +{ + public MiddlewareAttribute(string componentMethod) + { + ComponentMethod = componentMethod; + } + public string? ComponentMethod { get; } +} diff --git a/Connected.Middleware/Connected.Middleware.csproj b/Connected.Middleware/Connected.Middleware.csproj new file mode 100644 index 0000000..15a8e85 --- /dev/null +++ b/Connected.Middleware/Connected.Middleware.csproj @@ -0,0 +1,14 @@ + + + + net7.0 + enable + enable + + + + + + + + diff --git a/Connected.Middleware/IMiddlewareService.cs b/Connected.Middleware/IMiddlewareService.cs new file mode 100644 index 0000000..5711316 --- /dev/null +++ b/Connected.Middleware/IMiddlewareService.cs @@ -0,0 +1,16 @@ +using Connected.ServiceModel; +using System.Collections.Immutable; + +namespace Connected.Middleware; + +public interface IMiddlewareService +{ + Task First() where TMiddleware : IMiddleware; + Task> Query() where TMiddleware : IMiddleware; + Task> Query(ICallerContext? context) where TMiddleware : IMiddleware; + + Task First(Type type); + Task> Query(Type type); + Task> Query(Type type, ICallerContext? context); + +} diff --git a/Connected.Middleware/MiddlewareComponent.cs b/Connected.Middleware/MiddlewareComponent.cs new file mode 100644 index 0000000..08b9459 --- /dev/null +++ b/Connected.Middleware/MiddlewareComponent.cs @@ -0,0 +1,39 @@ +namespace Connected.Middleware; + +public abstract class MiddlewareComponent : IMiddleware, IDisposable +{ + + protected bool IsDisposed { get; private set; } + + public async Task Initialize() + { + await OnInitialize(); + } + + protected virtual async Task OnInitialize() + { + await Task.CompletedTask; + } + + protected virtual void OnDisposing() + { + + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + OnDisposing(); + + IsDisposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Middleware/MiddlewareExtensions.cs b/Connected.Middleware/MiddlewareExtensions.cs new file mode 100644 index 0000000..31d9974 --- /dev/null +++ b/Connected.Middleware/MiddlewareExtensions.cs @@ -0,0 +1,22 @@ +namespace Connected.Middleware +{ + public static class MiddlewareExtensions + { + public static List GetImplementedMiddleware(this Type type) + { + var result = new List(); + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (typeof(IMiddleware).FullName is not string fullname) + continue; + + if (i.GetInterface(fullname) is not null) + result.Add(i); + } + + return result; + } + } +} diff --git a/Connected.Middleware/MiddlewareService.cs b/Connected.Middleware/MiddlewareService.cs new file mode 100644 index 0000000..25539b2 --- /dev/null +++ b/Connected.Middleware/MiddlewareService.cs @@ -0,0 +1,232 @@ +using Connected.Collections; +using Connected.Configuration.Environment; +using Connected.Interop; +using Connected.Middleware.Annotations; +using Connected.ServiceModel; +using System.Collections.Concurrent; +using System.Collections.Immutable; +using System.Reflection; + +namespace Connected.Middleware; + +internal class MiddlewareService : IMiddlewareService +{ + private static readonly object _lock = new(); + + static MiddlewareService() + { + Endpoints = new(); + } + public MiddlewareService(IEnvironmentService environmentService, IContext context) + { + EnvironmentService = environmentService; + Context = context; + + if (!IsInitialized) + { + lock (_lock) + { + if (!IsInitialized) + Initialize(); + } + } + } + + private static bool IsInitialized { get; set; } + private IEnvironmentService EnvironmentService { get; } + private IContext Context { get; } + private static ConcurrentDictionary> Endpoints { get; set; } + + public async Task> Query() where TEndpoint : IMiddleware + { + return await Query(null); + } + + public async Task> Query(Type type) + { + return await Query(type, null); + } + + public async Task> Query(ICallerContext? context) where TEndpoint : IMiddleware + { + var key = typeof(TEndpoint).FullName; + + if (key is null || Endpoints is null) + return ImmutableList.Empty; + + if (!Endpoints.TryGetValue(key, out List? items) || items is null) + return ImmutableList.Empty; + + var result = new List(); + + foreach (var type in items) + { + if (!Validate(context, type)) + continue; + + if (Context.GetService(type) is object service) + result.Add((TEndpoint)service); + } + + result.SortByPriority(); + + foreach (var r in result) + await r.Initialize(); + + return result.ToImmutableList(); + } + + public async Task> Query(Type type, ICallerContext? context) + { + var key = type.FullName; + + if (key is null || Endpoints is null) + return ImmutableList.Empty; + + if (!Endpoints.TryGetValue(key, out List? items) || items is null) + return ImmutableList.Empty; + + var result = new List(); + + foreach (var t in items) + { + if (!Validate(context, t)) + continue; + + if (Context.GetService(type) is IMiddleware service) + result.Add(service); + } + + result.SortByPriority(); + + foreach (var r in result) + await r.Initialize(); + + return result.ToImmutableList(); + } + public async Task First() + where TEndpoint : IMiddleware + { + var key = typeof(TEndpoint).FullName; + + if (key is null || Endpoints is null) + return default; + + if (!Endpoints.TryGetValue(key, out List? items) || items is null) + return default; + + var types = new List(); + + foreach (var type in items) + types.Add(type); + + if (!types.Any()) + return default; + + types.SortByPriority(); + + foreach (var type in types) + { + if (!Validate(null, type)) + continue; + + if (Context.GetService(type) is object service) + { + var r = (TEndpoint)service; + await r.Initialize(); + + return r; + } + } + + return default; + } + + public async Task First(Type type) + { + var key = type.FullName; + + if (key is null || Endpoints is null) + return default; + + if (!Endpoints.TryGetValue(key, out List? items) || items is null) + return default; + + var types = new List(); + + foreach (var t in items) + types.Add(t); + + if (!types.Any()) + return default; + + types.SortByPriority(); + + foreach (var t in types) + { + if (!Validate(null, t)) + continue; + + if (Context.GetService(t) is IMiddleware service) + { + await service.Initialize(); + + return service; + } + } + + return default; + } + + private static bool Validate(ICallerContext? context, Type type) + { + if (context is null) + return true; + + var attributes = type.GetCustomAttributes(); + + foreach (var attribute in attributes) + { + if (attribute.GetType() != typeof(MiddlewareAttribute<>)) + continue; + + var method = attribute.GetType().GetProperty(nameof(MiddlewareAttribute.ComponentMethod)); + + if (method is null) + continue; + + if (!string.Equals(TypeConversion.Convert(method.GetValue(attribute)), context.Method, StringComparison.Ordinal)) + continue; + + var argument = attribute.GetType().GetGenericArguments()[0]; + + if (argument != context.Sender?.GetType()) + continue; + + return true; + } + + return false; + } + + private void Initialize() + { + IsInitialized = true; + + foreach (var endpoint in EnvironmentService.Services.IoCEndpoints) + { + var endpoints = endpoint.GetImplementedMiddleware(); + + foreach (var ep in endpoints) + { + if (string.IsNullOrWhiteSpace(ep.FullName)) + continue; + + if (Endpoints.TryGetValue(ep.FullName, out List? list)) + list.Add(endpoint); + else + Endpoints.TryAdd(ep.FullName, new List { endpoint }); + } + } + } +} diff --git a/Connected.Middleware/MiddlewareStartup.cs b/Connected.Middleware/MiddlewareStartup.cs new file mode 100644 index 0000000..6e713b6 --- /dev/null +++ b/Connected.Middleware/MiddlewareStartup.cs @@ -0,0 +1,14 @@ +using Connected.Annotations; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Middleware; + +internal class MiddlewareStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddScoped(typeof(IMiddlewareService), typeof(MiddlewareService)); + } +} diff --git a/Connected.Net/Connected.Net.csproj b/Connected.Net/Connected.Net.csproj new file mode 100644 index 0000000..a80be5b --- /dev/null +++ b/Connected.Net/Connected.Net.csproj @@ -0,0 +1,38 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Net/Endpoints/IEndpoint.cs b/Connected.Net/Endpoints/IEndpoint.cs new file mode 100644 index 0000000..5a9f98a --- /dev/null +++ b/Connected.Net/Endpoints/IEndpoint.cs @@ -0,0 +1,29 @@ +using Connected.Data; + +namespace Connected.Net.Endpoints; + +/// +/// Represents a single endpoint in the environment. +/// +/// +/// Each instance should be defined as an endpoint for other instances +/// to be able to find it on the network and communicate with it. +/// +public interface IEndpoint : IPrimaryKey +{ + /// + /// The name of the Endpoint. This is descriptive property and has + /// no special meaning in the environment. + /// + string Name { get; init; } + /// + /// The IP address or Url of the instance where it responds to requests. This value + /// should be unique across the environment. + /// + string Address { get; init; } + /// + /// The endpoint status. For dynamic scale outs, endpoint could be defined but disabled when + /// not needed. If scaling is needed this value could automatically go to the . + /// + Status Status { get; init; } +} diff --git a/Connected.Net/Endpoints/IEndpointService.cs b/Connected.Net/Endpoints/IEndpointService.cs new file mode 100644 index 0000000..5a0cca8 --- /dev/null +++ b/Connected.Net/Endpoints/IEndpointService.cs @@ -0,0 +1,31 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.ServiceModel; + +namespace Connected.Net.Endpoints; + +/// +/// This is for entity. +/// +/// +/// The environment consists of one or more instances, each described as an . +/// If more that one exists, it means the environment is scaled out. +/// There should be always at least one defined. +/// +[Service] +public interface IEndpointService +{ + /// + /// This method returns all registered endpoints in the current environment + /// + /// List of entities. + [ServiceMethod(ServiceMethodVerbs.Get)] + Task> Query(); + /// + /// This method returns for the specified id. + /// + /// The id of the endpoint for which the entity will be returned. + /// for the specified id if exists, null otherwise. + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs e); +} diff --git a/Connected.Net/HttpExtensions.cs b/Connected.Net/HttpExtensions.cs new file mode 100644 index 0000000..1f39a63 --- /dev/null +++ b/Connected.Net/HttpExtensions.cs @@ -0,0 +1,163 @@ +using System.Net; +using System.Text; +using System.Text.Json.Nodes; +using Connected.Interop; +using Microsoft.AspNetCore.Http; + +namespace Connected.Net; + +public static class HttpExtensions +{ + private const string RequestArgumentsKey = "TP-REQUEST-ARGUMENTS"; + + public static async Task Deserialize(this HttpRequest request) + { + if (await ReadText(request) is not string text || string.IsNullOrWhiteSpace(text)) + return default; + + if (JsonNode.Parse(text, new JsonNodeOptions { PropertyNameCaseInsensitive = true }) is not JsonNode node) + return default; + + request.HttpContext.SetRequestArguments(node); + + return node; + } + public static async Task Deserialize(this HttpRequest request) + { + if (await ReadText(request) is not string text || string.IsNullOrWhiteSpace(text)) + return default; + + return await Serializer.Deserialize(text); + } + + private static async Task ReadText(HttpRequest request) + { + using var reader = new StreamReader(request.Body, Encoding.UTF8); + + return await reader.ReadToEndAsync(); + } + + public static JsonNode? GetRequestArguments(this HttpContext context) + { + var result = context.Items[RequestArgumentsKey]; + + return result is null ? null : (JsonNode)result; + } + + public static void SetRequestArguments(this HttpContext context, JsonNode arguments) + { + if (context is null) + return; + + context.Items[RequestArgumentsKey] = arguments; + } + + public static async Task Get(this IHttpService factory, string? requestUri, CancellationToken cancellationToken = default) + { + return await HandleResponse(await factory.CreateClient().SendAsync(CreateGetMessage(requestUri), cancellationToken)); + } + + public static async Task Get(this IHttpService factory, string? requestUri, object content, CancellationToken cancellationToken = default) + { + return await HandleResponse(await factory.CreateClient().SendAsync(await CreateGetMessage(requestUri, content), cancellationToken)); + } + + public static async Task Post(this IHttpService factory, string? requestUri, object? content, CancellationToken cancellationToken = default) + { + await HandleResponse(await factory.CreateClient().SendAsync(await CreatePostMessage(requestUri, content), cancellationToken)); + } + public static async Task Post(this IHttpService factory, string? requestUri, object? content, CancellationToken cancellationToken = default) + { + return await HandleResponse(await factory.CreateClient().SendAsync(await CreatePostMessage(requestUri, content), cancellationToken)); + } + private static HttpRequestMessage CreateGetMessage(string? requestUri) + { + return new HttpRequestMessage(HttpMethod.Get, requestUri); + } + + private static async Task CreateGetMessage(string? requestUri, object content) + { + return new HttpRequestMessage(HttpMethod.Get, requestUri) + { + Content = await CreateJsonContent(content) + }; + } + + private static async Task CreatePostMessage(string? requestUri, object? content) + { + return new HttpRequestMessage(HttpMethod.Post, requestUri) + { + Content = await CreateJsonContent(content) + }; + } + private static async Task HandleResponse(HttpResponseMessage response) + { + if (!response.IsSuccessStatusCode) + await HandleResponseException(response); + } + + private static async Task HandleResponse(HttpResponseMessage response) + { + if (!response.IsSuccessStatusCode) + await HandleResponseException(response); + + var content = response.Content.ReadAsStringAsync().Result; + + if (IsNull(content)) + return default; + + return await Serializer.Deserialize(content); + } + + private static async Task HandleResponseException(HttpResponseMessage response) + { + var ex = await ParseException(response.Content); + + if (ex is null) + throw new WebException(response.ReasonPhrase); + + var source = string.Empty; + var message = string.Empty; + + if (ex["source"] is JsonNode sourceNode) + source = sourceNode.GetValue(); + + if (ex["message"] is JsonNode messageNode) + message = messageNode.GetValue(); + + throw new WebException(message) { Source = source }; + } + + private static async Task ParseException(HttpContent responseContent) + { + if (responseContent is null) + return default; + + try + { + var rt = responseContent.ReadAsStringAsync().Result; + + return await Serializer.Deserialize(rt); + } + catch + { + return null; + } + } + private static bool IsNull(string content) + { + return string.Equals(content, "null", StringComparison.OrdinalIgnoreCase) || string.IsNullOrWhiteSpace(content); + } + + private static async Task CreateJsonContent(object? content) + { + if (content is null || Convert.IsDBNull(content)) + return new StringContent(string.Empty); + + if (await Serializer.Serialize(content) is not string c) + return new StringContent(string.Empty, Encoding.UTF8, "application/json"); + + return new StringContent(c, Encoding.UTF8, "application/json"); + } + +} diff --git a/Connected.Net/HttpService.cs b/Connected.Net/HttpService.cs new file mode 100644 index 0000000..cb47e3d --- /dev/null +++ b/Connected.Net/HttpService.cs @@ -0,0 +1,19 @@ +using Connected.Net; + +namespace Connected.Net +{ + internal class HttpService : IHttpService + { + public HttpService(IHttpClientFactory factory) + { + Factory = factory; + } + + private IHttpClientFactory Factory { get; } + + public HttpClient CreateClient() + { + return Factory.CreateClient(); + } + } +} diff --git a/Connected.Net/Hubs/Client.cs b/Connected.Net/Hubs/Client.cs new file mode 100644 index 0000000..1b18491 --- /dev/null +++ b/Connected.Net/Hubs/Client.cs @@ -0,0 +1,52 @@ +using Connected.Interop.Reflection; + +namespace Connected.Net.Hubs; + +public enum MessageClientBehavior +{ + Reliable = 1, + FireForget = 2 +} + +public sealed class Client : IComparable> +{ + public Client(string connection) + { + if (string.IsNullOrWhiteSpace(connection)) + throw new ArgumentException(null, nameof(connection)); + + Connection = connection; + } + + public string Connection { get; set; } + public int User { get; set; } + public TArgs? Arguments { get; set; } + public DateTime RetentionDeadline { get; set; } + public MessageClientBehavior Behavior { get; set; } = MessageClientBehavior.Reliable; + + public int CompareTo(Client? other) + { + if (other is null) + return 1; + + if (User != other.User) + return 1; + + if (Behavior != other.Behavior) + return 1; + + if (Arguments is null && other.Arguments is null) + return 0; + + if (Arguments is null && other.Arguments is not null) + return -1; + + if (Arguments is not null && other.Arguments is null) + return 1; + + if (ObjectComparer.Compare(Arguments, other.Arguments)) + return 0; + + return -1; + } +} diff --git a/Connected.Net/Hubs/ClientMessages.cs b/Connected.Net/Hubs/ClientMessages.cs new file mode 100644 index 0000000..d34019f --- /dev/null +++ b/Connected.Net/Hubs/ClientMessages.cs @@ -0,0 +1,86 @@ +using System.Collections.Concurrent; +using System.Collections.Immutable; +using Connected.Collections; + +namespace Connected.Net.Hubs; + +public sealed class ClientMessages +{ + private readonly ConcurrentDictionary> _clients; + public ClientMessages() + { + _clients = new(StringComparer.OrdinalIgnoreCase); + } + private ConcurrentDictionary> Clients => _clients; + public void Clean() + { + foreach (var client in Clients) + { + client.Value.Scave(); + + if (client.Value.IsEmpty) + Clients.TryRemove(client.Key, out _); + } + } + public ImmutableList> Dequeue() + { + var result = new List>(); + + foreach (var client in Clients) + { + var items = client.Value.Dequeue(); + + if (!items.IsEmpty) + result.AddRange(items); + } + + return result.ToImmutableList(true); + } + + public void Remove(string connectionId) + { + foreach (var client in Clients) + { + client.Value.Remove(connectionId); + + if (client.Value.IsEmpty) + Clients.TryRemove(client.Key, out _); + } + } + public void Remove(string connection, IMessageAcknowledgeArgs e) + { + if (!Clients.TryGetValue(connection, out Messages? items)) + return; + + items.Remove(e.Id); + + if (items.IsEmpty) + Clients.TryRemove(connection, out _); + } + public void Remove(string connection, string key) + { + if (string.IsNullOrEmpty(connection)) + return; + + if (!Clients.TryGetValue(connection, out Messages? items)) + return; + + items.Remove(connection, key); + + if (items.IsEmpty) + Clients.Remove(connection, out _); + } + + public void Add(string client, Message message) + { + if (!Clients.TryGetValue(client, out Messages? items)) + { + items = new Messages(); + + if (!Clients.TryAdd(client, items)) + Clients.TryGetValue(client, out items); + } + + items.Add(message); + } +} diff --git a/Connected.Net/Hubs/Clients.cs b/Connected.Net/Hubs/Clients.cs new file mode 100644 index 0000000..85c5c6a --- /dev/null +++ b/Connected.Net/Hubs/Clients.cs @@ -0,0 +1,50 @@ +using System.Collections.Concurrent; +using System.Collections.Immutable; + +namespace Connected.Net.Hubs; + +public sealed class Clients +{ + public Clients() + { + Items = new(StringComparer.OrdinalIgnoreCase); + } + private ConcurrentDictionary> Items { get; set; } + public void AddOrUpdate(Client client) + { + if (!Items.TryGetValue(client.Connection, out Client? existing)) + { + Items.TryAdd(client.Connection, client); + return; + } + + existing.RetentionDeadline = DateTime.MinValue; + } + + public void Clean() + { + var dead = Items.Where(f => f.Value.RetentionDeadline != DateTime.MinValue && f.Value.RetentionDeadline <= DateTime.UtcNow).ToImmutableList(); + + if (dead.IsEmpty) + return; + + foreach (var client in dead) + Items.TryRemove(client.Key, out _); + } + + public void Remove(string connectionId) + { + if (!Items.TryGetValue(connectionId, out Client? client)) + return; + + if (client.Behavior == MessageClientBehavior.FireForget) + Items.TryRemove(connectionId, out _); + else + client.RetentionDeadline = DateTime.UtcNow.AddMinutes(5); + } + + public ImmutableList> Query() + { + return Items.Values.ToImmutableList(); + } +} diff --git a/Connected.Net/Hubs/IServer.cs b/Connected.Net/Hubs/IServer.cs new file mode 100644 index 0000000..d3a56bd --- /dev/null +++ b/Connected.Net/Hubs/IServer.cs @@ -0,0 +1,15 @@ +using Connected.Net; + +namespace Connected.Net.Hubs +{ + public interface IServer + { + event EventHandler? Received; + ClientMessages Messages { get; } + Clients Clients { get; } + + Task Send(TArgs args); + Task Receive(TArgs args); + Task Acknowledge(string connection, IMessageAcknowledgeArgs args); + } +} diff --git a/Connected.Net/Hubs/Message.cs b/Connected.Net/Hubs/Message.cs new file mode 100644 index 0000000..96f71ae --- /dev/null +++ b/Connected.Net/Hubs/Message.cs @@ -0,0 +1,22 @@ +namespace Connected.Net.Hubs +{ + public sealed class Message + { + private static ulong _identity = 0UL; + + public Message(Client client, TArgs args) + { + Client = client; + Arguments = args; + Id = Interlocked.Increment(ref _identity); + Expire = DateTime.UtcNow.AddMinutes(5); + } + + public Client Client { get; } + public ulong Id { get; } + public string? Key { get; set; } + public TArgs? Arguments { get; } + public DateTime NextVisible { get; set; } = DateTime.UtcNow.AddSeconds(5); + public DateTime Expire { get; } + } +} diff --git a/Connected.Net/Hubs/Messages.cs b/Connected.Net/Hubs/Messages.cs new file mode 100644 index 0000000..cb82153 --- /dev/null +++ b/Connected.Net/Hubs/Messages.cs @@ -0,0 +1,71 @@ +using System.Collections.Immutable; +using Connected.Collections; + +namespace Connected.Net.Hubs; + +internal sealed class Messages +{ + private readonly List> _items; + + public Messages() + { + _items = new List>(); + } + private List> Items => _items; + public bool IsEmpty => !Items.Any(); + public ImmutableList> All() => Items.ToImmutableList(true); + public void Scave() + { + var items = All().Where(f => f.Expire <= DateTime.UtcNow); + + foreach (var item in items) + Items.Remove(item); + } + + public ImmutableList> Dequeue() + { + var items = All().Where(f => f.NextVisible <= DateTime.UtcNow).ToImmutableList(true); + + if (!items.Any()) + return ImmutableList>.Empty; + + foreach (var item in items) + item.NextVisible = item.NextVisible.AddSeconds(5); + + return items; + } + + public void Remove(string connectionId) + { + var items = All().Where(f => string.Equals(f.Client.Connection, connectionId, StringComparison.OrdinalIgnoreCase)); + + foreach (var item in items) + Items.Remove(item); + } + + public void Remove(ulong id) + { + if (All().FirstOrDefault(f => f.Id == id) is Message message) + Items.Remove(message); + } + + public void Remap(string connection) + { + foreach (var item in All()) + item.Client.Connection = connection; + } + + public void Remove(string connection, string key) + { + var obsolete = All().Where(f => string.Equals(f.Key, key, StringComparison.OrdinalIgnoreCase) + && string.Equals(f.Client.Connection, connection, StringComparison.OrdinalIgnoreCase)); + + foreach (var o in obsolete) + Items.Remove(o); + } + + public void Add(Message message) + { + Items.Add(message); + } +} diff --git a/Connected.Net/Hubs/Server.cs b/Connected.Net/Hubs/Server.cs new file mode 100644 index 0000000..3fc5d01 --- /dev/null +++ b/Connected.Net/Hubs/Server.cs @@ -0,0 +1,55 @@ +using Connected.Net.Messaging; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Net.Hubs; + +public abstract class Server : IServer + where THub : Hub +{ + public event EventHandler? Received; + protected Server(IHubContext hub) + { + Messages = new(); + Clients = new(); + + Hub = hub; + } + + public ClientMessages Messages { get; } + + public Clients Clients { get; } + private IHubContext Hub { get; } + + public virtual async Task Send(TArgs args) + { + foreach (var client in Clients.Query()) + { + var message = new Message(client, args); + + Messages.Add(client.Connection, message); + + await Hub.Clients.Client(client.Connection).SendCoreAsync("Notify", new object[] { new MessageAcknowledgeArgs(message.Id), args }); + } + } + + public virtual async Task Receive(TArgs args) + { + Received?.Invoke(this, args); + + await Task.CompletedTask; + } + + public async Task Acknowledge(string connection, IMessageAcknowledgeArgs args) + { + try + { + Messages.Remove(connection, args); + + await Task.CompletedTask; + } + catch (Exception ex) + { + await Hub.Clients.AllExcept(connection).SendCoreAsync("Exception", new object[] { new ServerExceptionArgs { Message = ex.Message } }); + } + } +} diff --git a/Connected.Net/Hubs/ServerExceptionArgs.cs b/Connected.Net/Hubs/ServerExceptionArgs.cs new file mode 100644 index 0000000..3a1d3ce --- /dev/null +++ b/Connected.Net/Hubs/ServerExceptionArgs.cs @@ -0,0 +1,7 @@ +namespace Connected.Net.Hubs +{ + public sealed class ServerExceptionArgs : EventArgs + { + public string? Message { get; set; } + } +} diff --git a/Connected.Net/Hubs/ServerWorker.cs b/Connected.Net/Hubs/ServerWorker.cs new file mode 100644 index 0000000..3607091 --- /dev/null +++ b/Connected.Net/Hubs/ServerWorker.cs @@ -0,0 +1,45 @@ +using Connected.Hosting.Workers; +using Connected.Net.Messaging; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Net.Hubs; + +public abstract class ServerWorker : ScheduledWorker + where THub : Hub +{ + protected ServerWorker(IServer server, IHubContext hub) + { + Server = server; + Hub = hub; + Timer = TimeSpan.FromMilliseconds(500); + } + + protected IServer Server { get; } + private IHubContext Hub { get; } + + protected override async Task OnInvoke(CancellationToken cancellationToken) + { + await Send(cancellationToken); + /* + * Clean up every 15 seconds + */ + if (Count % 30 == 0) + await Clean(cancellationToken); + } + + private async Task Send(CancellationToken cancellationToken) + { + var messages = Server.Messages.Dequeue(); + + foreach (var item in messages) + await Hub.Clients.Client(item.Client.Connection).SendCoreAsync("Notify", new object[] { new MessageAcknowledgeArgs(item.Id), item.Arguments }, cancellationToken); + } + + private async Task Clean(CancellationToken cancellationToken) + { + Server.Messages.Clean(); + Server.Clients.Clean(); + + await Task.CompletedTask; + } +} diff --git a/Connected.Net/Hubs/StatefulHub.cs b/Connected.Net/Hubs/StatefulHub.cs new file mode 100644 index 0000000..b818867 --- /dev/null +++ b/Connected.Net/Hubs/StatefulHub.cs @@ -0,0 +1,49 @@ +using Connected.Net.Messaging; +using Connected.Security.Identity; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Net.Hubs; + +public abstract class StatefulHub : Hub> +{ + protected StatefulHub(IServer server) + { + Server = server; + } + + protected IServer Server { get; } + /// + /// This method is called by the client connection + /// + /// + /// + public async Task Notify(TArgs args) + { + await Server.Receive(args); + } + /// + /// This method is called by the client connection + /// + /// + /// + public async Task Acknowledge(MessageAcknowledgeArgs args) + { + await Server.Acknowledge(Context.ConnectionId, args); + } + + public override async Task OnConnectedAsync() + { + var user = 0; + + if (Context.User?.Identity is UserIdentity identity && identity.User is not null) + user = identity.User.Id; + + Server.Clients.AddOrUpdate(new Client(Context.ConnectionId) + { + User = user, + Behavior = MessageClientBehavior.Reliable + }); + + await base.OnConnectedAsync(); + } +} diff --git a/Connected.Net/Messaging/IEndpointClient.cs b/Connected.Net/Messaging/IEndpointClient.cs new file mode 100644 index 0000000..5720400 --- /dev/null +++ b/Connected.Net/Messaging/IEndpointClient.cs @@ -0,0 +1,9 @@ +using Connected.Net.Hubs; + +namespace Connected.Net.Messaging; + +public interface IEndpointClient +{ + Task Notify(IMessageAcknowledgeArgs ack, TArgs? args); + Task Exception(ServerExceptionArgs args); +} diff --git a/Connected.Net/Messaging/MessageAcknowledgeArgs.cs b/Connected.Net/Messaging/MessageAcknowledgeArgs.cs new file mode 100644 index 0000000..25aa6eb --- /dev/null +++ b/Connected.Net/Messaging/MessageAcknowledgeArgs.cs @@ -0,0 +1,14 @@ +using Connected.Net; + +namespace Connected.Net.Messaging +{ + public sealed class MessageAcknowledgeArgs : EventArgs, IMessageAcknowledgeArgs + { + public MessageAcknowledgeArgs(ulong id) + { + Id = id; + } + + public ulong Id { get; } + } +} diff --git a/Connected.Net/NetClaims.cs b/Connected.Net/NetClaims.cs new file mode 100644 index 0000000..8940f3a --- /dev/null +++ b/Connected.Net/NetClaims.cs @@ -0,0 +1,7 @@ +namespace Connected.Net +{ + public static class NetClaims + { + public const string NetDiscovery = "Net Discovery"; + } +} diff --git a/Connected.Net/NetStartup.cs b/Connected.Net/NetStartup.cs new file mode 100644 index 0000000..a77d353 --- /dev/null +++ b/Connected.Net/NetStartup.cs @@ -0,0 +1,16 @@ +using Connected.Annotations; +using Connected.Net.Server; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Net; + +internal class NetStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(IEndpointServer), typeof(EndpointServer)); + services.AddSingleton(typeof(IHttpService), typeof(HttpService)); + } +} diff --git a/Connected.Net/Routes.cs b/Connected.Net/Routes.cs new file mode 100644 index 0000000..bfd0ee8 --- /dev/null +++ b/Connected.Net/Routes.cs @@ -0,0 +1,8 @@ +namespace Connected.Net +{ + public static class Routes + { + public const string EndpointsService = "/sys/endpoints"; + public const string Server = "/sys/server"; + } +} diff --git a/Connected.Net/SR.Designer.cs b/Connected.Net/SR.Designer.cs new file mode 100644 index 0000000..0a5df82 --- /dev/null +++ b/Connected.Net/SR.Designer.cs @@ -0,0 +1,108 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Net { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Net.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Cannot resolve Endpoint server.. + /// + internal static string ErrCannotResolveEndpoint { + get { + return ResourceManager.GetString("ErrCannotResolveEndpoint", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Endpoint not found. + /// + internal static string ErrEndpointNull { + get { + return ResourceManager.GetString("ErrEndpointNull", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The instance is already initialized. Calling initialize more than once is not allowed.. + /// + internal static string ErrInitialized { + get { + return ResourceManager.GetString("ErrInitialized", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to This method is allowed only on an Endpoint server instance.. + /// + internal static string ErrNotServer { + get { + return ResourceManager.GetString("ErrNotServer", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Instance is not registered. + /// + internal static string ValInstanceNotRegistered { + get { + return ResourceManager.GetString("ValInstanceNotRegistered", resourceCulture); + } + } + } +} diff --git a/Connected.Net/SR.resx b/Connected.Net/SR.resx new file mode 100644 index 0000000..ca71f96 --- /dev/null +++ b/Connected.Net/SR.resx @@ -0,0 +1,135 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cannot resolve Endpoint server. + + + Endpoint not found + + + The instance is already initialized. Calling initialize more than once is not allowed. + + + This method is allowed only on an Endpoint server instance. + + + Instance is not registered + + \ No newline at end of file diff --git a/Connected.Net/Server/EndpointServer.cs b/Connected.Net/Server/EndpointServer.cs new file mode 100644 index 0000000..1d2405a --- /dev/null +++ b/Connected.Net/Server/EndpointServer.cs @@ -0,0 +1,212 @@ +using System.Collections.Immutable; +using Connected.Configuration; +using Connected.Net.Endpoints; + +namespace Connected.Net.Server; + +/// +/// This class handles communication between environment endpoints. +/// +internal sealed class EndpointServer : IEndpointServer +{ + public event EventHandler? Changed; + public event EventHandler Initialized; + private static ServerProposalArgs? _proposal; + + public EndpointServer(IHttpService http, IConfigurationService configurationService) + { + Endpoints = ImmutableList.Empty; + ConfigurationService = configurationService; + Http = http; + + if (ConfigurationService is null) + throw new NullReferenceException(nameof(IConfigurationService)); + + if (Http is null) + throw new NullReferenceException(nameof(IHttpService)); + } + private ImmutableList Endpoints { get; set; } + private IHttpService Http { get; } + private IConfigurationService ConfigurationService { get; } + private EndpointServerDescriptor? Server { get; set; } + private bool IsInitialized { get; set; } + internal static ServerProposalArgs? ProposalArgs => _proposal; + public async Task Initialize(ImmutableList endpoints, CancellationToken cancellationToken) + { + if (IsInitialized) + throw new SysException(this, SR.ErrInitialized); + + Endpoints = endpoints; + + ValidateInstance(); + InitializeProposal(); + + await ResolveServer(cancellationToken); + + IsInitialized = true; + + Initialized?.Invoke(this, EventArgs.Empty); + } + /// + /// Creates instance used when negotiating with + /// other instances for taking a server role. + /// + /// + private void InitializeProposal() + { + // If there are no endpoints defined we probably won't need arguments anyway. + if (Endpoints is null || !Endpoints.Any()) + { + _proposal = new ServerProposalArgs(); + + return; + } + // Create proposal arguments with the id of this endpoint. + _proposal = new ServerProposalArgs + { + Id = Endpoints.First(f => string.Equals(f.Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)).Id + }; + } + /// + /// Validatates this instance against the endpoints configuration. + /// If endpoints configuration does not have any records we are + /// probably (but not necessarily true) in a single instance environment. + /// Note that endpoints table should always contain at least one record. + /// If endpoints contain at least one record that means our address must + /// match with one record in the endpoints configuration. + /// + private void ValidateInstance() + { + if (Endpoints is null || !Endpoints.Any()) + return; + + if (Endpoints.FirstOrDefault(f => string.Equals(f.Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)) is null) + throw new SysException(this, SR.ValInstanceNotRegistered); + } + /// + /// Returns true if this instance is currently the environment's server. + /// + public Task IsServer() => Task.FromResult(Server is null ? true : !Server.IsRemote); + /// + /// Returns the url of the currently active environment server. + /// + public string ServerUrl + { + get + { + if (Server is null) + throw new NullReferenceException(nameof(Server)); + + if (Server.Endpoint is null) + throw new NullReferenceException(nameof(Server.Endpoint)); + + return Server.Endpoint.Address; + } + } + /// + /// This method tries to resolve which process () is the server in the + /// current environment (network). + /// + private async Task ResolveServer(CancellationToken cancellationToken) + { + var protocol = new ServerProtocol(Endpoints, this, ConfigurationService, Http); + var newServer = await protocol.ResolveServer(cancellationToken); + // Server didn't change. + if (string.Equals(newServer?.Address, Server?.Endpoint?.Address, StringComparison.OrdinalIgnoreCase)) + return; + // We have a new server. If we are the server we must notify other endpoints that we are acting as + // an endpoint server. If not, the chosen server will notify us. + Server = new EndpointServerDescriptor(await protocol.ResolveServer(cancellationToken), ConfigurationService); + + if (!await IsServer()) + return; + + await AnnounceServerChange(); + } + /// + /// This method notifies all endpoints that we are acting as a server. + /// + /// + private async Task AnnounceServerChange() + { + if (Endpoints is null || !Endpoints.Any()) + return; + + foreach (var endpoint in Endpoints) + { + // No need to notify ourselves. + if (string.Equals(endpoint.Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)) + continue; + + await Http.Post($"{Routes.EndpointsService}/{nameof(NotifyServerChange)}", ProposalArgs); + } + } + /// + /// Thos method is called from the Environment server notifying us it is + /// acting as a server. + /// + /// The arguments acting as a proof that conditions + /// are met. + /// If invalid endpoints id has been passed. + public async Task NotifyServerChange(ServerProposalArgs args) + { + // We should probaby valdate the arguments here again in case + // some race condition happened and more thar one server chose + // to be the one + + if (Endpoints.FirstOrDefault(f => f.Id == args.Id) is not IEndpoint endpoint) + throw new NullReferenceException($"{SR.ErrEndpointNull} ({args.Id})"); + + if (Server is not null && Server.Endpoint is not null) + { + // Already points to the active server + if (Server.Endpoint.Id == args.Id) + return; + } + + Server = new EndpointServerDescriptor(endpoint, ConfigurationService); + + await Task.CompletedTask; + } + + private async Task ChangeServer(IEndpoint endpoint) + { + // Nothing changed + if (endpoint is null && (Server is null || Server.Endpoint is null)) + return; + + if (endpoint is not null && Server is not null && Server.Endpoint is not null && endpoint.Id == Server.Endpoint.Id) + return; + // Now change the server's endpoint + Server = new EndpointServerDescriptor(endpoint, ConfigurationService); + // Notify this environment about server change. + Changed?.Invoke(this, new ServerChangedArgs + { + Endpoint = endpoint, + IsRemote = Server.IsRemote + }); + } + /// + /// Returns whether proposal from the remote instance is accepted by this instance. + /// + /// Remote instance proposal. + /// true if remote proposed arguments are chosen, false otherwise. + public bool Propose(ServerProposalArgs e) + { + // The algorithm is simple, the oldest arguments win. That means the instance than started first + // is the winner. + if (e.TimeStamp < ProposalArgs.TimeStamp) + return true; + else if (e.TimeStamp > ProposalArgs.TimeStamp) + return false; + // There is a small chance both timestamps are the same. In that case we choose the winner based + // on weight. + if (e.Weight < ProposalArgs.Weight) + return true; + else if (e.Weight > ProposalArgs.Weight) + return false; + // Well, this is really almost impossible but if it happened the system wouldn't be able to decide + // which one to choose so we will play the random game. + return Random.Shared.Next(1, 100) > 50; + } +} diff --git a/Connected.Net/Server/EndpointServerDescriptor.cs b/Connected.Net/Server/EndpointServerDescriptor.cs new file mode 100644 index 0000000..069e94a --- /dev/null +++ b/Connected.Net/Server/EndpointServerDescriptor.cs @@ -0,0 +1,31 @@ +using Connected.Configuration; +using Connected.Net.Endpoints; + +namespace Connected.Net.Server; + +/// +/// Contains information about the currently active endpoint server. There is only one endpoint server at one point in the environment. +/// In the case of a single instance environment the points to this instance. +/// +internal class EndpointServerDescriptor +{ + /// + /// Create a new instance of the + /// + /// The currently server. Could be null if there are no endpoints defined. + /// + public EndpointServerDescriptor(IEndpoint? endpoint, IConfigurationService configurationService) + { + Endpoint = endpoint; + IsRemote = Endpoint is not null && !string.Equals(configurationService.Endpoint.Address, Endpoint.Address, StringComparison.OrdinalIgnoreCase); + } + /// + /// The Endpoint information about the server. This could point to this instance or can be null which also means + /// this instance is an . + /// + public IEndpoint? Endpoint { get; } + /// + /// Returns true if the Endpoint server is not this instance. + /// + public bool IsRemote { get; } +} diff --git a/Connected.Net/Server/IEndpointServer.cs b/Connected.Net/Server/IEndpointServer.cs new file mode 100644 index 0000000..8bb5289 --- /dev/null +++ b/Connected.Net/Server/IEndpointServer.cs @@ -0,0 +1,23 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.Net.Endpoints; + +namespace Connected.Net.Server; + +[Service] +[ServiceUrl(Routes.Server)] +public interface IEndpointServer +{ + event EventHandler? Changed; + event EventHandler Initialized; + + Task Initialize(ImmutableList endpoints, CancellationToken cancellationToken); + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task NotifyServerChange(ServerProposalArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get)] + Task IsServer(); + + string ServerUrl { get; } +} diff --git a/Connected.Net/Server/IServerConnection.cs b/Connected.Net/Server/IServerConnection.cs new file mode 100644 index 0000000..7cdf211 --- /dev/null +++ b/Connected.Net/Server/IServerConnection.cs @@ -0,0 +1,7 @@ +namespace Connected.Net.Server +{ + public interface IServerConnection : IDisposable, IAsyncDisposable + { + Task Notify(string method, T args); + } +} diff --git a/Connected.Net/Server/ServerChangedArgs.cs b/Connected.Net/Server/ServerChangedArgs.cs new file mode 100644 index 0000000..a796eb4 --- /dev/null +++ b/Connected.Net/Server/ServerChangedArgs.cs @@ -0,0 +1,9 @@ +using Connected.Net.Endpoints; + +namespace Connected.Net.Server; + +public class ServerChangedArgs +{ + public IEndpoint? Endpoint { get; init; } + public bool IsRemote { get; init; } +} diff --git a/Connected.Net/Server/ServerConnection.cs b/Connected.Net/Server/ServerConnection.cs new file mode 100644 index 0000000..10270e1 --- /dev/null +++ b/Connected.Net/Server/ServerConnection.cs @@ -0,0 +1,100 @@ +using Microsoft.AspNetCore.SignalR.Client; + +namespace Connected.Net.Server +{ + public abstract class ServerConnection : IDisposable, IAsyncDisposable, IServerConnection + { + private HubConnection? _connection; + protected ServerConnection(IEndpointServer server) + { + if (server is null) + throw new ArgumentException(null, nameof(server)); + + Server = server; + } + + public virtual async Task Initialize(string hubUrl) + { + if (_connection is not null) + { + if (_connection.State != HubConnectionState.Disconnected) + await _connection.StopAsync(); + + await _connection.DisposeAsync(); + } + + _connection = new HubConnectionBuilder() + .WithUrl($"{Server.ServerUrl}/{hubUrl.Trim('/')}") + .WithAutomaticReconnect() + .Build(); + } + + private bool IsDisposed { get; set; } + protected HubConnection Connection => _connection; + + protected IEndpointServer Server { get; } + + public async Task Connect() + { + if (Connection.State != HubConnectionState.Disconnected) + return; + + await Connection.StartAsync(); + } + + public async Task Disconnect() + { + if (Connection is null) + return; + + if (Connection.State == HubConnectionState.Disconnected) + await Connection.StopAsync(); + } + + public async Task Notify(string method, TArgs args) + { + // TODO: Buffer message if connection is not opened and send it later. + if (Connection.State != HubConnectionState.Connected) + return; + + if (Connection.State == HubConnectionState.Connected) + await Connection.SendAsync(method, args); + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + if (_connection is not null) + { + _connection.DisposeAsync() + .GetAwaiter() + .GetResult(); + + _connection = null; + } + } + + IsDisposed = true; + } + } + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public async ValueTask DisposeAsync() + { + if (_connection is not null) + { + await _connection.DisposeAsync(); + _connection = null; + } + + GC.SuppressFinalize(this); + } + } +} diff --git a/Connected.Net/Server/ServerProposalArgs.cs b/Connected.Net/Server/ServerProposalArgs.cs new file mode 100644 index 0000000..9d50e00 --- /dev/null +++ b/Connected.Net/Server/ServerProposalArgs.cs @@ -0,0 +1,15 @@ +using Connected.ServiceModel; + +namespace Connected.Net.Server; + +public sealed class ServerProposalArgs : Dto +{ + public ServerProposalArgs() + { + Weight = Random.Shared.Next(int.MaxValue); + } + + public DateTime TimeStamp { get; init; } = DateTime.UtcNow; + public int Weight { get; init; } + public int Id { get; init; } +} diff --git a/Connected.Net/Server/ServerProtocol.cs b/Connected.Net/Server/ServerProtocol.cs new file mode 100644 index 0000000..23e8aea --- /dev/null +++ b/Connected.Net/Server/ServerProtocol.cs @@ -0,0 +1,130 @@ +using System.Collections.Immutable; +using Connected.Configuration; +using Connected.Net.Endpoints; + +namespace Connected.Net.Server; + +/// +/// This class resolves Endpoints server. There should be one and only one Endpoint server in the environment (network). +/// +/// +/// No instance is preconfigured to be an Endpoint server thus enabling the system to not have a single point of failure. If there is +/// only one instance in the environment that instance acts as an Endopint server as well. As other instances are created/loaded they will +/// connect to this Endpoint (server). If this instance goes down or restarts one of the other instances will take the server role ensuring if +/// we have at least one process active in the environment we have one and only one Endpoint server as well. +/// +internal class ServerProtocol +{ + public ServerProtocol(ImmutableList endpoints, IEndpointServer server, IConfigurationService configurationService, IHttpService http) + { + Endpoints = endpoints; + ConfigurationService = configurationService; + Http = http; + Server = server as EndpointServer; + } + private ImmutableList Endpoints { get; } + private EndpointServer Server { get; } + private IConfigurationService ConfigurationService { get; } + private IHttpService Http { get; } + + public async Task ResolveServer(CancellationToken cancellationToken = default) + { + /* + * If there are no endpoints defined this process is the + * only instance in the environment. + */ + if (Endpoints is null || !Endpoints.Any()) + return null; + /* + * First try to find existing one. + */ + if (await Lookup(Endpoints, cancellationToken) is IEndpoint existing) + return existing; + /* + * Check if we are the only instance in the environment. If so, we are the server of course. + */ + if (Endpoints.Count == 1 && string.Equals(Endpoints[0].Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)) + return Endpoints[0]; + /* + * We are in the scale out environment and things got a bit complicated. + * No server exist is the environment but we have at least two instances. + * We must now negotiate and eventually choose a winner which will act + * as a server. + * If our proposal fails there must be a better candidate to be a server we just need to + * look for it again. + */ + if (await Propose(Endpoints, cancellationToken) is IEndpoint server) + return server; + /* + * No luck we just hope now that other instance was actually chosen to be a server. + * If if don't find it now we are in big trouble because we couldn't agree which one + * are gonna be a server but the platform must have an Endpoint server. + */ + if (await Lookup(Endpoints, cancellationToken) is not IEndpoint endpoint) + throw new SysException(this, SR.ErrCannotResolveEndpoint); + /* + * We are lucky and have a server. + */ + return endpoint; + } + /// + /// This method tries to find an existing server in the environment. + /// + /// + /// The Endpoint which acts as a server. Null if there is no server. + private async Task Lookup(ImmutableList endpoints, CancellationToken cancellationToken) + { + foreach (var endpoint in endpoints) + { + /* + * Don't connect to itself. + */ + if (string.Equals(endpoint.Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)) + continue; + /* + * We are done if we found a endpoint server. + */ + if (await IsServer(endpoint, cancellationToken)) + return endpoint; + } + + return default; + } + /// + /// Tries to connect to the endpoint and if connection is establied successfully + /// finds out if the endpoint is already a server. + /// If so, the negotiation is completed. + /// + private async Task IsServer(IEndpoint endpoint, CancellationToken cancellationToken = default) + { + return await Http.Get($"{endpoint.Address}/{Routes.EndpointsService}/{nameof(Server.IsServer)}", cancellationToken); + } + /// + /// Thos method negotiates with other instances and tries to eventually chooses the one which will act + /// as a server. + /// + /// The chosen endpoint + private async Task Propose(ImmutableList endpoints, CancellationToken cancellationToken = default) + { + /* + * The algorithm is as follows: + * - compare proposal arguments with other instances + * - if our proposal is the oldest and with the largest value we are the server + * - otherwise some other instance will act as a server + */ + foreach (var endpoint in endpoints) + { + if (string.Equals(endpoint.Address, ConfigurationService.Endpoint.Address, StringComparison.OrdinalIgnoreCase)) + continue; + /* + * If at least one endpoint returns false it means it has a better proposal to be a server. + */ + if (!await Http.Post($"{endpoint.Address}/{Routes.EndpointsService}/{nameof(Server.Propose)}", EndpointServer.ProposalArgs, cancellationToken)) + return null; + } + /* + * All endpoints agreed we are the best candidate to be a server. + */ + return endpoints.First(f => string.Equals(f.Address, ConfigurationService.Endpoint.Address)); + } +} diff --git a/Connected.Notifications/Connected.Notifications.csproj b/Connected.Notifications/Connected.Notifications.csproj new file mode 100644 index 0000000..0e559be --- /dev/null +++ b/Connected.Notifications/Connected.Notifications.csproj @@ -0,0 +1,14 @@ + + + + net7.0 + enable + enable + + + + + + + + diff --git a/Connected.Notifications/Events/EventDispatcher.cs b/Connected.Notifications/Events/EventDispatcher.cs new file mode 100644 index 0000000..016581e --- /dev/null +++ b/Connected.Notifications/Events/EventDispatcher.cs @@ -0,0 +1,11 @@ +using Connected.Collections.Concurrent; + +namespace Connected.Notifications.Events; + +internal sealed class EventDispatcher : Dispatcher +{ + public EventDispatcher() : base(128) + { + + } +} diff --git a/Connected.Notifications/Events/EventDispatcherJob.cs b/Connected.Notifications/Events/EventDispatcherJob.cs new file mode 100644 index 0000000..ebe6831 --- /dev/null +++ b/Connected.Notifications/Events/EventDispatcherJob.cs @@ -0,0 +1,25 @@ +using Connected.Collections.Concurrent; + +namespace Connected.Notifications.Events; + +internal sealed class EventDispatcherJob : DispatcherJob +{ + public EventDispatcherJob(IEventService events) + { + Events = events as EventService; + } + + public EventService? Events { get; } + + protected override async Task OnInvoke(EventServiceArgs args, CancellationToken cancellationToken) + { + if (Events is null) + return; + /* + * We have one simple taskto do: trigger event on the event service component. + * Other clients are listening to that event and respond in a different way + * i.e. broadcasting event to the clients + */ + await Events.Trigger(args); + } +} diff --git a/Connected.Notifications/Events/EventListener.cs b/Connected.Notifications/Events/EventListener.cs new file mode 100644 index 0000000..1ca6f56 --- /dev/null +++ b/Connected.Notifications/Events/EventListener.cs @@ -0,0 +1,23 @@ +using Connected.Middleware; +using Connected.ServiceModel; + +namespace Connected.Notifications.Events; +public abstract class EventListener : MiddlewareComponent, IEventListener + where TArgs : IDto +{ + protected TArgs Arguments { get; private set; } = default!; + protected IOperationState Sender { get; private set; } = default!; + + public async Task Invoke(IOperationState sender, TArgs args) + { + Sender = sender; + Arguments = args; + + await OnInvoke(); + } + + protected virtual async Task OnInvoke() + { + await Task.CompletedTask; + } +} diff --git a/Connected.Notifications/Events/EventService.cs b/Connected.Notifications/Events/EventService.cs new file mode 100644 index 0000000..ed05c16 --- /dev/null +++ b/Connected.Notifications/Events/EventService.cs @@ -0,0 +1,149 @@ +using System.Reflection; +using Connected.Interop; +using Connected.Middleware; +using Connected.Net.Server; +using Connected.Notifications.Events.Server; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; + +namespace Connected.Notifications.Events; + +internal class EventService : IEventService, IDisposable +{ + public event ServiceEventHandler? Event; + + public EventService(IEndpointServer endpoints, EventDispatcher dispatcher, EventServer server, EventServerConnection backplaneClient, IContextProvider provider) + { + Dispatcher = dispatcher; + Server = server; + BackplaneClient = backplaneClient; + Provider = provider; + Endpoints = endpoints; + + endpoints.Changed += OnServerChanged; + endpoints.Initialized += OnServerInitialized; + + BackplaneClient.Received += OnReceived; + } + + private IEndpointServer Endpoints { get; } + + private EventDispatcher Dispatcher { get; set; } + private EventServer Server { get; } + private EventServerConnection BackplaneClient { get; } + private IContextProvider Provider { get; } + private bool IsDisposed { get; set; } + + private async void OnServerInitialized(object? sender, EventArgs e) + { + await Initialize(); + } + + public async Task Initialize() + { + await BackplaneClient.Disconnect(); + + try + { + if (!await Endpoints.IsServer()) + { + await BackplaneClient.Initialize(Endpoints.ServerUrl); + await BackplaneClient.Connect(); + } + } + catch + { + // Server probably not initalized yet + } + } + + private async void OnServerChanged(object? sender, ServerChangedArgs e) + { + await Initialize(); + } + + private async void OnReceived(object? sender, EventNotificationArgs e) + { + //TODO: we are going to need some additional info here, like assembly + var serviceType = Type.GetType(e.Service); + + if (serviceType is null) + return; + + //Enqueue(serviceType, e.Event, EventOrigin.Remote, e.Arguments); + + await Task.CompletedTask; + } + + internal async Task Trigger(EventServiceArgs args) + { + Event?.Invoke(args.Sender, args); + + await Server.Send(new EventNotificationArgs + { + Arguments = await Serializer.Serialize(args.Arguments), + Event = args.Event, + Service = args.Service.GetType().Name + }); + + using var context = Provider.Create(); + + var middlewareService = context.GetService(); + var targetMiddleware = typeof(IEventListener<>); + var gt = targetMiddleware.MakeGenericType(args.Arguments.GetType()); + var middleware = await middlewareService.Query(gt, new CallerContext(args.Service, args.Event)); + + foreach (var m in middleware) + { + if (m.GetType().GetMethod(nameof(IEventListener.Invoke), BindingFlags.Public | BindingFlags.Instance, new Type[] { typeof(IOperationState), args.Arguments.GetType() }) is not MethodInfo method) + continue; + + await method.InvokeAsync(m, args.Sender, args.Arguments); + } + + if (context.GetService() is ITransactionContext transactionContext) + await transactionContext.Commit(); + } + + public async Task Enqueue(IOperationState sender, TService service, string @event, TArgs args) + { + await Enqueue(sender, service, @event, EventOrigin.InProcess, args); + } + + public async Task Enqueue(IOperationState sender, TService service, string @event, EventOrigin origin, TArgs args) + { + Dispatcher.Enqueue(new EventServiceArgs + { + Sender = sender, + Service = service.GetType(), + Event = @event, + Origin = origin, + Arguments = args + }); + + await Task.CompletedTask; + } + + protected virtual void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + if (Dispatcher is not null) + { + Dispatcher.Cancel(); + Dispatcher = null; + } + } + + IsDisposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Notifications/Events/EventServiceArgs.cs b/Connected.Notifications/Events/EventServiceArgs.cs new file mode 100644 index 0000000..cc25024 --- /dev/null +++ b/Connected.Notifications/Events/EventServiceArgs.cs @@ -0,0 +1,18 @@ +using Connected.ServiceModel; + +namespace Connected.Notifications.Events; + +public enum EventOrigin +{ + InProcess = 1, + Remote = 2 +} + +public sealed class EventServiceArgs : Dto +{ + public IOperationState? Sender { get; init; } + public Type Service { get; init; } + public string? Event { get; init; } + public EventOrigin Origin { get; init; } + public object? Arguments { get; init; } +} diff --git a/Connected.Notifications/Events/IEventListener.cs b/Connected.Notifications/Events/IEventListener.cs new file mode 100644 index 0000000..e9c0ca4 --- /dev/null +++ b/Connected.Notifications/Events/IEventListener.cs @@ -0,0 +1,9 @@ +using Connected.ServiceModel; + +namespace Connected.Notifications.Events; + +public interface IEventListener : IMiddleware + where TArgs : IDto +{ + Task Invoke(IOperationState sender, TArgs args); +} diff --git a/Connected.Notifications/Events/IEventService.cs b/Connected.Notifications/Events/IEventService.cs new file mode 100644 index 0000000..5338df5 --- /dev/null +++ b/Connected.Notifications/Events/IEventService.cs @@ -0,0 +1,11 @@ +using Connected.ServiceModel; + +namespace Connected.Notifications.Events +{ + public interface IEventService + { + event ServiceEventHandler? Event; + + Task Enqueue(IOperationState sender, TService service, string @event, TArgs args); + } +} diff --git a/Connected.Notifications/Events/Server/EventHub.cs b/Connected.Notifications/Events/Server/EventHub.cs new file mode 100644 index 0000000..1b45736 --- /dev/null +++ b/Connected.Notifications/Events/Server/EventHub.cs @@ -0,0 +1,10 @@ +using Connected.Net.Hubs; + +namespace Connected.Notifications.Events.Server; + +internal class EventHub : StatefulHub +{ + public EventHub(EventServer server) : base(server) + { + } +} diff --git a/Connected.Notifications/Events/Server/EventNotificationArgs.cs b/Connected.Notifications/Events/Server/EventNotificationArgs.cs new file mode 100644 index 0000000..5c1ba51 --- /dev/null +++ b/Connected.Notifications/Events/Server/EventNotificationArgs.cs @@ -0,0 +1,11 @@ +using Connected; + +namespace Connected.Notifications.Events.Server +{ + internal class EventNotificationArgs : IDto + { + public string? Service { get; set; } + public string? Event { get; set; } + public string? Arguments { get; set; } + } +} diff --git a/Connected.Notifications/Events/Server/EventServer.cs b/Connected.Notifications/Events/Server/EventServer.cs new file mode 100644 index 0000000..483f476 --- /dev/null +++ b/Connected.Notifications/Events/Server/EventServer.cs @@ -0,0 +1,11 @@ +using Connected.Net.Hubs; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Notifications.Events.Server; + +internal class EventServer : Server +{ + public EventServer(IHubContext hub) : base(hub) + { + } +} diff --git a/Connected.Notifications/Events/Server/EventServerConnection.cs b/Connected.Notifications/Events/Server/EventServerConnection.cs new file mode 100644 index 0000000..aa4a2e5 --- /dev/null +++ b/Connected.Notifications/Events/Server/EventServerConnection.cs @@ -0,0 +1,36 @@ +using Connected.Net.Hubs; +using Connected.Net.Messaging; +using Connected.Net.Server; +using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.Extensions.Logging; + +namespace Connected.Notifications.Events.Server; + +internal class EventServerConnection : ServerConnection +{ + public event EventHandler? Received; + + public EventServerConnection(IEndpointServer server, ILogger logger) : base(server) + { + Logger = logger; + } + + private ILogger Logger { get; } + + public override async Task Initialize(string hubUrl) + { + await base.Initialize(hubUrl); + + Connection.On("Notify", (a, e) => + { + Connection.InvokeAsync(nameof(EventHub.Acknowledge), a); + + Received?.Invoke(this, e); + }); + + Connection.On("Exception", (e) => + { + Logger.LogError("Caching hub exception: {message}", e.Message); + }); + } +} diff --git a/Connected.Notifications/Events/Server/EventWorker.cs b/Connected.Notifications/Events/Server/EventWorker.cs new file mode 100644 index 0000000..1b17af8 --- /dev/null +++ b/Connected.Notifications/Events/Server/EventWorker.cs @@ -0,0 +1,11 @@ +using Connected.Net.Hubs; +using Microsoft.AspNetCore.SignalR; + +namespace Connected.Notifications.Events.Server; + +internal sealed class EventWorker : ServerWorker +{ + public EventWorker(EventServer server, IHubContext hub) : base(server, hub) + { + } +} diff --git a/Connected.Notifications/NotificationComponent.cs b/Connected.Notifications/NotificationComponent.cs new file mode 100644 index 0000000..ec296c1 --- /dev/null +++ b/Connected.Notifications/NotificationComponent.cs @@ -0,0 +1,15 @@ +namespace Connected.Notifications +{ + public abstract class NotificationComponent + { + protected NotificationComponent() + { + + } + + protected void Notify() + { + + } + } +} diff --git a/Connected.Notifications/NotificationsStartup.cs b/Connected.Notifications/NotificationsStartup.cs new file mode 100644 index 0000000..91a5fc2 --- /dev/null +++ b/Connected.Notifications/NotificationsStartup.cs @@ -0,0 +1,27 @@ +using Connected.Annotations; +using Connected.Notifications.Events; +using Connected.Notifications.Events.Server; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Notifications; + +internal class NotificationsStartup : Startup +{ + public const string EventsHub = "/events"; + + protected override void OnConfigure(WebApplication app) + { + app.MapHub(EventsHub); + } + + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(EventServer)); + services.AddSingleton(typeof(EventServerConnection)); + services.AddSingleton(typeof(IEventService), typeof(EventService)); + services.AddTransient(typeof(EventDispatcher)); + } +} diff --git a/Connected.Rest/Api/ApiFormatter.cs b/Connected.Rest/Api/ApiFormatter.cs new file mode 100644 index 0000000..ceee406 --- /dev/null +++ b/Connected.Rest/Api/ApiFormatter.cs @@ -0,0 +1,29 @@ +using System.Text.Json.Nodes; +using Microsoft.AspNetCore.Http; + +namespace Connected.Rest; + +internal abstract class ApiFormatter +{ + public HttpContext Context { get; set; } + public async Task ParseArguments() + { + return await OnParseArguments(); + } + + protected abstract Task OnParseArguments(); + + public async Task RenderError(int statusCode, string message) + { + await OnRenderError(statusCode, message); + } + + protected abstract Task OnRenderError(int statusCode, string message); + + public async Task RenderResult(object content) + { + await OnRenderResult(content); + } + + protected abstract Task OnRenderResult(object content); +} diff --git a/Connected.Rest/Api/ApiInvokeDescriptor.cs b/Connected.Rest/Api/ApiInvokeDescriptor.cs new file mode 100644 index 0000000..f075503 --- /dev/null +++ b/Connected.Rest/Api/ApiInvokeDescriptor.cs @@ -0,0 +1,13 @@ +using System.Reflection; +using Connected.Annotations; + +namespace Connected.Rest; + +internal class ApiInvokeDescriptor +{ + public Type? Service { get; set; } + public MethodInfo? Method { get; set; } + + public Type[]? Parameters { get; set; } + public ServiceMethodVerbs Verbs { get; set; } = ServiceMethodVerbs.None; +} diff --git a/Connected.Rest/Api/ApiResolutionService.cs b/Connected.Rest/Api/ApiResolutionService.cs new file mode 100644 index 0000000..80f0e6f --- /dev/null +++ b/Connected.Rest/Api/ApiResolutionService.cs @@ -0,0 +1,207 @@ +using System.Collections.Immutable; +using System.Reflection; +using System.Text; +using Connected.Annotations; +using Connected.Configuration.Environment; +using Connected.ServiceModel; +using Connected.Services; +using Microsoft.AspNetCore.Http; + +namespace Connected.Rest; + +internal class ApiResolutionService : IApiResolutionService +{ + private readonly Dictionary> _methods; + private readonly Dictionary> _arguments; + + public ApiResolutionService(IEnvironmentService environmentService) + { + _methods = new(StringComparer.OrdinalIgnoreCase); + _arguments = new(StringComparer.OrdinalIgnoreCase); + + EnvironmentService = environmentService; + + Initialize(); + } + + private Dictionary> Methods => _methods; + private Dictionary> Arguments => _arguments; + private IEnvironmentService EnvironmentService { get; } + + /// + /// This method tries to resolve argument implementation type based on the parameter's type interface. + /// + /// The implementation parameter of the method which declares the argument + /// that implements 's type interface. + public Type? ResolveArgument(ParameterInfo parameter) + { + if (!Arguments.ContainsKey(ArgumentName(parameter.ParameterType))) + return null; + + var items = Arguments[ArgumentName(parameter.ParameterType)]; + /* + * If the interafce has only one implementation the air is clear. + */ + if (items.Count == 1) + return WrapArgument(items[0].Type, parameter); + /* + * We have more than one implementation. We'll try to find the implementation that match + * the assembly of the invoking method. This is the most probable scenario. + */ + foreach (var argument in items) + { + if (argument.Type.Assembly == parameter.ParameterType.Assembly) + return WrapArgument(argument.Type, parameter); + } + /* + * Method's assembly doesn't have an implementation, let's try to look in the + * interface's assembly. + */ + foreach (var argument in items) + { + if (argument.Type.Assembly == parameter.ParameterType.Assembly) + return WrapArgument(argument.Type, parameter); + } + /* + * Nope, there must be some intermediate assembly implementing the argument and it surely must + * be referenced by the method's assembly. + */ + return WrapArgument(items[0].Type, parameter); + } + + private static Type WrapArgument(Type argument, ParameterInfo parameter) + { + if (!argument.IsGenericType) + return argument; + + return argument.MakeGenericType(parameter.ParameterType.GetGenericArguments()); + } + + public ApiInvokeDescriptor? ResolveMethod(HttpContext context) + { + var route = context.Request.Path.Value; + + if (route is null) + return null; + + if (!Methods.TryGetValue(route.ToString(), out List? descriptor)) + return null; + + //TODO: map overloads from arguments + return descriptor[0]; + } + + private void Initialize() + { + foreach (var type in EnvironmentService.Services.Services) + InitializeApiService(type); + + foreach (var type in EnvironmentService.Services.Arguments) + InitializeArgument(type); + } + + private void InitializeArgument(Type type) + { + if (type.GetImplementedArguments() is not List arguments || !arguments.Any()) + return; + + foreach (var argument in arguments) + { + var name = ArgumentName(argument); + + if (Arguments.TryGetValue(name, out _)) + Arguments[name].Add(new ArgumentDescriptor { Type = type }); + else + Arguments.Add(name, new List { new ArgumentDescriptor { Type = type } }); + } + } + + private void InitializeApiService(Type type) + { + if (type.GetImplementedServices() is not List services || !services.Any()) + return; + + foreach (var service in services) + { + var serviceUrl = ResolveServiceUrl(service); + var methods = service.GetMethods(BindingFlags.Public | BindingFlags.Instance); + + foreach (var method in methods) + { + if (method.GetCustomAttribute() is not ServiceMethodAttribute attribute || attribute.Verbs == ServiceMethodVerbs.None) + continue; + + InitializeServiceMethod(serviceUrl, service, method, attribute.Verbs); + } + } + } + + private void InitializeServiceMethod(string serviceUrl, Type serviceType, MethodInfo method, ServiceMethodVerbs verbs) + { + var parameterTypes = new List(); + + foreach (var parameter in method.GetParameters()) + parameterTypes.Add(parameter.ParameterType); + + var targetMethod = serviceType.GetMethod(method.Name, parameterTypes.ToArray()); + var methodUrl = $"{serviceUrl}/{ResolveMethodUrl(targetMethod)}"; + var descriptor = new ApiInvokeDescriptor { Service = serviceType, Method = targetMethod, Parameters = parameterTypes.ToArray(), Verbs = verbs }; + + if (Methods.TryGetValue(methodUrl, out List? items)) + items.Add(descriptor); + else + Methods.Add(methodUrl, new List { descriptor }); + } + + public ImmutableList> QueryRoutes() + { + var result = new List>(); + + foreach (var method in Methods) + { + var verbs = ServiceMethodVerbs.None; + + foreach (var descriptor in method.Value) + verbs |= descriptor.Verbs; + + result.Add(Tuple.Create(method.Key, verbs)); + } + + return result.ToImmutableList(); + } + + private static string ResolveServiceUrl(Type type) + { + if (type.GetCustomAttribute() is ServiceUrlAttribute attribute) + return attribute.Template; + + return $"{PascalNamespace(type.Namespace)}/{type.Name.ToPascalCase()}".Replace('.', '/'); + } + + private static string ResolveMethodUrl(MethodInfo method) + { + if (method.GetCustomAttribute() is ServiceUrlAttribute attribute) + return attribute.Template; + + return method.Name.ToCamelCase(); + } + + private static string ArgumentName(Type argument) + { + return $"{argument.Namespace}.{argument.Name}, {argument.Assembly.FullName}"; + } + + private static string? PascalNamespace(string? @namespace) + { + if (string.IsNullOrEmpty(@namespace)) + return null; + + var tokens = @namespace.Split('.'); + var result = new StringBuilder(); + + foreach (var token in tokens) + result.Append($"{token.ToPascalCase()}."); + + return result.ToString().TrimEnd('.'); + } +} diff --git a/Connected.Rest/Api/ApiServiceRequestDelegate.cs b/Connected.Rest/Api/ApiServiceRequestDelegate.cs new file mode 100644 index 0000000..ab11373 --- /dev/null +++ b/Connected.Rest/Api/ApiServiceRequestDelegate.cs @@ -0,0 +1,319 @@ +using System.Reflection; +using System.Text.Json.Nodes; +using Connected.Annotations; +using Connected.Interop; +using Connected.Interop.Annotations; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; + +namespace Connected.Rest; + +internal sealed class ApiServiceRequestDelegate : IDisposable +{ + private ApiFormatter _formatter = null; + private IContext _context; + + public ApiServiceRequestDelegate(HttpContext httpContext) + { + HttpContext = httpContext; + _context = httpContext.RequestServices.GetService().Create(); + } + + private bool IsDisposed { get; set; } + private HttpContext HttpContext { get; } + private IContext Context => _context; + /// + /// This method invokes the Api Service method with the request parameters. + /// + /// Result ot the Api method or nothing is the methods return type is void. + public async Task InvokeAsync() + { + /* + * First, try to get appropriate method (target) from the resolution service. + * Methods must be defined with interface which have ApiServie attribute + */ + if (Context.GetService()?.ResolveMethod(HttpContext) is not ApiInvokeDescriptor descriptor) + { + await RenderError(StatusCodes.Status404NotFound); + return; + } + /* + * Event if the method is found we must validate if it is defined for the current Http method. + */ + if (!await ValidateVerb(descriptor)) + return; + /* + * Now map request arguments with the method's one. + */ + var arguments = await MapArgumentsAsync(descriptor.Method); + /* + * And instantiate the Scoped service from the DI. + */ + var service = Context.GetService(descriptor.Service); + /* + * Invoking the method with parsed arguments and rendering results with the formatter + * specified in the request content type (probably Json). + */ + var result = await Methods.InvokeAsync(descriptor.Method, service, arguments?.ToArray()); + /* + * Now, commit changes made in the context. + */ + if (Context.GetService() is ITransactionContext transaction) + await transaction.Commit(); + /* + * Send result to the client. + */ + await RenderResult(result); + } + + private async Task RenderError(int statusCode) + { + await Formatter.RenderError(statusCode, null); + } + + private async Task RenderResult(object content) + { + await Formatter.RenderResult(content); + } + + private async Task ValidateVerb(ApiInvokeDescriptor descriptor) + { + if (string.Equals(HttpContext.Request.Method, HttpMethods.Get, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Get) != ServiceMethodVerbs.Get) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + else if (string.Equals(HttpContext.Request.Method, HttpMethods.Post, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Post) != ServiceMethodVerbs.Post) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + else if (string.Equals(HttpContext.Request.Method, HttpMethods.Put, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Put) != ServiceMethodVerbs.Put) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + else if (string.Equals(HttpContext.Request.Method, HttpMethods.Delete, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Delete) != ServiceMethodVerbs.Delete) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + else if (string.Equals(HttpContext.Request.Method, HttpMethods.Patch, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Patch) != ServiceMethodVerbs.Patch) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + else if (string.Equals(HttpContext.Request.Method, HttpMethods.Options, StringComparison.OrdinalIgnoreCase)) + { + if ((descriptor.Verbs & ServiceMethodVerbs.Options) != ServiceMethodVerbs.Options) + { + await RenderError(StatusCodes.Status405MethodNotAllowed); + return false; + } + } + + return true; + } + /// + /// This method maps request arguments to method arguments. + /// + /// The to which arguments will be mapped. + /// List of method's arguments needed to successfully invoke a method. + /// Thrown if a method argument is interface but no is present. + private async Task> MapArgumentsAsync(MethodInfo method) + { + var arguments = new List(); + var requestParams = await ParseArgumentsAsync(); + /* + * Look for all method parameters. Note that this is already an implementation method not the interface one. + */ + foreach (var parameter in method.GetParameters()) + { + /* + * Most Api methods will have only one parameter which inherits from IDto. + */ + if (parameter.ParameterType.GetInterface(typeof(IDto).FullName) is not null && parameter.ParameterType.IsInterface) + { + /* + * If it's an interface type parameter it must be one of the following: + * - if it has an ArgsBindingAttribute<> we will create instance from its definition + * - we'll look into ApiDiscovery and try to match the implementation class automatically + */ + var attribute = parameter.GetCustomAttribute(typeof(ArgsBindingAttribute<>)); + + if (attribute is not null) + { + /* + * There is a type defined in an attribute which we need. + */ + var genericArguments = attribute.GetType().GetGenericArguments(); + var argument = Context.GetService(genericArguments[0]); + /* + * Merge request properties into argument instance. + */ + Serializer.Merge(requestParams, argument); + + arguments.Add(argument); + } + else + { + if (Context.GetService()?.ResolveArgument(parameter) is Type resolvedType) + { + var argument = Context.GetService(resolvedType); + /* + * Merge request properties into argument instance. + */ + Serializer.Merge(argument, requestParams); + + arguments.Add(argument); + } + else + throw new SysException(this, $"{SR.ErrBindingAttributeMissing} ({method.DeclaringType.FullName}.{method.Name})"); + } + } + else + { + /* + * It's not an IDto, we are currently supporting only types from DI. + * We are going to support binding to + */ + if (parameter.ParameterType.IsTypePrimitive()) + arguments.Add(ResolvePrimitiveArgument(parameter, requestParams)); + else if (Context.GetService(parameter.ParameterType) is object argument) + { + Serializer.Merge(argument, requestParams); + arguments.Add(argument); + } + else + { + await RenderError(StatusCodes.Status400BadRequest); + return null; + } + } + } + + return arguments; + } + private static object? ResolvePrimitiveArgument(ParameterInfo parameter, JsonNode? requestParams) + { + if (requestParams is JsonObject jobject) + { + if (jobject.ContainsKey(parameter.Name)) + { + var value = (object)jobject[parameter.Name].AsValue(); + + if (value is not null && TypeConversion.TryConvert(value, out object result, parameter.ParameterType)) + return result; + } + } + else if (requestParams is JsonValue value) + { + var val = (object)value[parameter.Name].AsValue(); + + if (val is not null && TypeConversion.TryConvert(val, out object result, parameter.ParameterType)) + return result; + } + + return null; + } + /// + /// This method parses Request arguments into JsonNode. + /// + /// A JsonNode representing request parameters. + private async Task ParseArgumentsAsync() + { + var method = HttpContext.Request.Method; + /* + * Post, Delete, Put and Patch methods have parameters in the request body, let formatter do the work. + */ + if (method.Equals(HttpMethods.Post, StringComparison.OrdinalIgnoreCase) || method.Equals(HttpMethods.Delete, StringComparison.Ordinal) + || method.Equals(HttpMethods.Put, StringComparison.OrdinalIgnoreCase) || method.Equals(HttpMethods.Patch, StringComparison.OrdinalIgnoreCase)) + return await Formatter.ParseArguments(); + else + { + /* + * For Get, Options and Trace use query string + */ + var r = new JsonObject(); + + foreach (var i in HttpContext.Request.Query.Keys) + r.Add(i, HttpContext.Request.Query[i].ToString()); + + return r; + } + } + + private ApiFormatter Formatter + { + get + { + if (_formatter is null) + { + var contentType = HttpContext.Request.ContentType; + + if (string.IsNullOrWhiteSpace(contentType)) + _formatter = new JsonFormatter(); + else + { + if (contentType.Contains(';')) + contentType = contentType.Split(';')[0].Trim(); + + if (string.Compare(contentType, JsonFormatter.ContentType, true) == 0) + _formatter = new JsonFormatter(); + else if (string.Compare(contentType, FormFormatter.ContentType, true) == 0) + _formatter = new FormFormatter(); + else + { + HttpContext.Response.StatusCode = StatusCodes.Status400BadRequest; + + throw new SysException(this, $"{SR.ErrContentTypeNotSupported} ({contentType})"); + } + } + + _formatter.Context = HttpContext; + } + + return _formatter; + } + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + { + if (_context is not null) + { + _context.Dispose(); + _context = null; + } + } + + IsDisposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Rest/Api/ArgumentDescriptor.cs b/Connected.Rest/Api/ArgumentDescriptor.cs new file mode 100644 index 0000000..6b429d0 --- /dev/null +++ b/Connected.Rest/Api/ArgumentDescriptor.cs @@ -0,0 +1,6 @@ +namespace Connected.Rest; + +internal class ArgumentDescriptor +{ + public Type Type { get; init; } +} diff --git a/Connected.Rest/Api/FormFormatter.cs b/Connected.Rest/Api/FormFormatter.cs new file mode 100644 index 0000000..d24eb1a --- /dev/null +++ b/Connected.Rest/Api/FormFormatter.cs @@ -0,0 +1,75 @@ +using System.Reflection; +using System.Text; +using System.Text.Json.Nodes; +using Connected.Interop; +using Connected.Net; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.WebUtilities; + +namespace Connected.Rest; + +internal class FormFormatter : ApiFormatter +{ + public const string ContentType = "application/x-www-form-urlencoded"; + protected override async Task OnParseArguments() + { + using var reader = new StreamReader(Context.Request.Body, Encoding.UTF8); + var body = await reader.ReadToEndAsync(); + var qs = QueryHelpers.ParseNullableQuery(body); + var result = new JsonObject(); + + foreach (var q in qs) + result.Add(q.Key, JsonValue.Create(q.Value)); + + Context.SetRequestArguments(result); + + return result; + } + + protected override async Task OnRenderResult(object content) + { + if (Context.Response.HasStarted) + { + var qs = new QueryString(); + + if (content is not null) + { + foreach (var property in content.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + if (!property.PropertyType.IsTypePrimitive()) + continue; + + qs.Add(property.Name.ToCamelCase(), property.GetValue(content) as string); + } + } + + var buffer = Encoding.UTF8.GetBytes(qs.ToUriComponent()); + + Context.Response.Clear(); + Context.Response.ContentLength = buffer.Length; + Context.Response.ContentType = ContentType; + Context.Response.StatusCode = StatusCodes.Status200OK; + + await Context.Response.Body.WriteAsync(buffer); + } + + await Context.Response.CompleteAsync(); + } + + protected override async Task OnRenderError(int statusCode, string message) + { + Context.Response.ContentType = ContentType; + Context.Response.StatusCode = statusCode; + + var qs = new QueryString(); + + qs.Add("message", message); + + var buffer = Encoding.UTF8.GetBytes(qs.ToUriComponent()); + + Context.Response.ContentLength = buffer.Length; + + await Context.Response.Body.WriteAsync(buffer); + await Context.Response.CompleteAsync(); + } +} diff --git a/Connected.Rest/Api/IApiResolutionService.cs b/Connected.Rest/Api/IApiResolutionService.cs new file mode 100644 index 0000000..e9caa7a --- /dev/null +++ b/Connected.Rest/Api/IApiResolutionService.cs @@ -0,0 +1,14 @@ +using System.Collections.Immutable; +using System.Reflection; +using Connected.Annotations; +using Microsoft.AspNetCore.Http; + +namespace Connected.Rest; + +internal interface IApiResolutionService +{ + ApiInvokeDescriptor? ResolveMethod(HttpContext context); + Type? ResolveArgument(ParameterInfo parameter); + + ImmutableList> QueryRoutes(); +} diff --git a/Connected.Rest/Api/JsonFormatter.cs b/Connected.Rest/Api/JsonFormatter.cs new file mode 100644 index 0000000..1155125 --- /dev/null +++ b/Connected.Rest/Api/JsonFormatter.cs @@ -0,0 +1,54 @@ +using System.Text; +using System.Text.Json.Nodes; +using Connected.Interop; +using Connected.Net; +using Microsoft.AspNetCore.Http; + +namespace Connected.Rest; + +internal class JsonFormatter : ApiFormatter +{ + public const string ContentType = "application/json"; + + protected override async Task OnParseArguments() + { + return await Context.Request.Deserialize(); + } + + protected override async Task OnRenderError(int statusCode, string message) + { + Context.Response.ContentType = ContentType; + Context.Response.StatusCode = statusCode; + + if (!string.IsNullOrWhiteSpace(message)) + { + var buffer = Encoding.UTF8.GetBytes(await Serializer.Serialize(new + { + Message = message + })); + + Context.Response.ContentLength = buffer.Length; + + await Context.Response.Body.WriteAsync(buffer); + } + + await Context.Response.CompleteAsync(); + } + + protected override async Task OnRenderResult(object content) + { + if (!Context.Response.HasStarted) + { + var buffer = content is null ? Array.Empty() : Encoding.UTF8.GetBytes(await Serializer.Serialize(content)); + + Context.Response.Clear(); + Context.Response.ContentLength = buffer.Length; + Context.Response.ContentType = ContentType; + Context.Response.StatusCode = StatusCodes.Status200OK; + + await Context.Response.Body.WriteAsync(buffer); + } + + await Context.Response.CompleteAsync(); + } +} diff --git a/Connected.Rest/Connected.Rest.csproj b/Connected.Rest/Connected.Rest.csproj new file mode 100644 index 0000000..561d7df --- /dev/null +++ b/Connected.Rest/Connected.Rest.csproj @@ -0,0 +1,37 @@ + + + + net7.0 + enable + enable + $(MSBuildProjectName.Replace(" ", "_")) + + + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Rest/Middleware/RestStartup.cs b/Connected.Rest/Middleware/RestStartup.cs new file mode 100644 index 0000000..2444f30 --- /dev/null +++ b/Connected.Rest/Middleware/RestStartup.cs @@ -0,0 +1,49 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Runtime)] + +namespace Connected.Rest; + +internal class RestStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(IApiResolutionService), typeof(ApiResolutionService)); + } + + protected override void OnConfigure(WebApplication app) + { + RegisterRoutes(app); + } + + private void RegisterRoutes(WebApplication app) + { + if (app.Services.GetService() is not IApiResolutionService resolution) + return; + + if (resolution.QueryRoutes() is not ImmutableList> routes || !routes.Any()) + return; + + foreach (var route in routes) + { + app.Map(route.Item1, async (httpContext) => + { + try + { + using var handler = new ApiServiceRequestDelegate(httpContext); + + await handler.InvokeAsync(); + + } + catch + { + //TODO: log exception + throw; + } + }); + } + } +} diff --git a/Connected.Rest/SR.Designer.cs b/Connected.Rest/SR.Designer.cs new file mode 100644 index 0000000..2740330 --- /dev/null +++ b/Connected.Rest/SR.Designer.cs @@ -0,0 +1,81 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Rest { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Connected.Rest.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Binding attribute is missing on a method argument. + /// + internal static string ErrBindingAttributeMissing { + get { + return ResourceManager.GetString("ErrBindingAttributeMissing", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Content type not supported. + /// + internal static string ErrContentTypeNotSupported { + get { + return ResourceManager.GetString("ErrContentTypeNotSupported", resourceCulture); + } + } + } +} diff --git a/Connected.Rest/SR.resx b/Connected.Rest/SR.resx new file mode 100644 index 0000000..ae127f6 --- /dev/null +++ b/Connected.Rest/SR.resx @@ -0,0 +1,126 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Binding attribute is missing on a method argument + + + Content type not supported + + \ No newline at end of file diff --git a/Connected.Runtime/Annotations/OrdinalAttribute.cs b/Connected.Runtime/Annotations/OrdinalAttribute.cs new file mode 100644 index 0000000..e9aeac6 --- /dev/null +++ b/Connected.Runtime/Annotations/OrdinalAttribute.cs @@ -0,0 +1,12 @@ +namespace Connected.Annotations +{ + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Class)] + public sealed class OrdinalAttribute : Attribute + { + public OrdinalAttribute(int value) + { + Value = value; + } + public int Value { get; } + } +} diff --git a/Connected.Runtime/Annotations/PriorityAttribute.cs b/Connected.Runtime/Annotations/PriorityAttribute.cs new file mode 100644 index 0000000..fc22864 --- /dev/null +++ b/Connected.Runtime/Annotations/PriorityAttribute.cs @@ -0,0 +1,12 @@ +namespace Connected.Annotations +{ + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Class | AttributeTargets.Interface)] + public sealed class PriorityAttribute : Attribute + { + public PriorityAttribute(int value) + { + Value = value; + } + public int Value { get; } + } +} diff --git a/Connected.Runtime/Annotations/ServiceRegistrationAttribute.cs b/Connected.Runtime/Annotations/ServiceRegistrationAttribute.cs new file mode 100644 index 0000000..0270240 --- /dev/null +++ b/Connected.Runtime/Annotations/ServiceRegistrationAttribute.cs @@ -0,0 +1,28 @@ +namespace Connected.Annotations +{ + public enum ServiceRegistrationMode + { + Auto = 1, + Manual = 2 + } + + public enum ServiceRegistrationScope + { + Singleton = 1, + Scoped = 2, + Transient = 3 + } + + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false)] + public sealed class ServiceRegistrationAttribute : Attribute + { + public ServiceRegistrationAttribute(ServiceRegistrationMode mode, ServiceRegistrationScope scope) + { + Mode = mode; + Scope = scope; + } + + public ServiceRegistrationMode Mode { get; } = ServiceRegistrationMode.Auto; + public ServiceRegistrationScope Scope { get; } = ServiceRegistrationScope.Scoped; + } +} diff --git a/Connected.Runtime/Connected.Runtime.csproj b/Connected.Runtime/Connected.Runtime.csproj new file mode 100644 index 0000000..d946671 --- /dev/null +++ b/Connected.Runtime/Connected.Runtime.csproj @@ -0,0 +1,19 @@ + + + + net7.0 + enable + enable + $(MSBuildProjectName.Replace(" ", "_")) + + + + + + + + + + + + diff --git a/Connected.Runtime/Data/IPopReceipt.cs b/Connected.Runtime/Data/IPopReceipt.cs new file mode 100644 index 0000000..a62dcad --- /dev/null +++ b/Connected.Runtime/Data/IPopReceipt.cs @@ -0,0 +1,29 @@ +namespace Connected.Data; +/// +/// Represents an entity with a conditional visibillity. +/// +/// +/// Some entities require a singleton access which protects +/// them from being processed by multiple clients at a time. This +/// entity serves for such purpose. One example is queue message which +/// must be processed only by a single client. But, on the other hand, +/// a client has only a limited available time to process it successfully. +/// If it's not processed in time, other client gets opportunity to process +/// the message. The isolation is achieved through the PopReceipt property +/// which is updated everytime client dequeues the message. This means other +/// clients can't successfully update (or delete) the message once other +/// clients was granted the access. +/// +public interface IPopReceipt +{ + /// + /// The id of the current scope. The id is available only upon the expiration + /// (NextVisible). + /// + Guid? PopReceipt { get; init; } + /// + /// The date and time the current PopReceipt expires and the access is granted to + /// other client. + /// + DateTime NextVisible { get; init; } +} diff --git a/Connected.Runtime/IStartup.cs b/Connected.Runtime/IStartup.cs new file mode 100644 index 0000000..83edb8a --- /dev/null +++ b/Connected.Runtime/IStartup.cs @@ -0,0 +1,12 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace Connected; + +public interface IStartup +{ + void ConfigureServices(IServiceCollection services); + void Configure(WebApplication app); + Task Initialize(Dictionary args); + Task Start(Dictionary args); +} diff --git a/Connected.Runtime/Middleware/CallerContext.cs b/Connected.Runtime/Middleware/CallerContext.cs new file mode 100644 index 0000000..27f6932 --- /dev/null +++ b/Connected.Runtime/Middleware/CallerContext.cs @@ -0,0 +1,14 @@ +namespace Connected.ServiceModel +{ + public sealed class CallerContext : ICallerContext + { + public CallerContext(object sender, string? method) + { + Sender = sender; + Method = method; + } + + public object Sender { get; } + public string? Method { get; } + } +} diff --git a/Connected.Runtime/Middleware/CancellationContext.cs b/Connected.Runtime/Middleware/CancellationContext.cs new file mode 100644 index 0000000..75051af --- /dev/null +++ b/Connected.Runtime/Middleware/CancellationContext.cs @@ -0,0 +1,14 @@ +namespace Connected.ServiceModel +{ + internal class CancellationContext : ICancellationContext + { + public CancellationContext(IContext context) + { + Context = context; + } + + public CancellationToken CancellationToken => Context is null ? CancellationToken.None : Context.CancellationToken; + + private IContext Context { get; } + } +} diff --git a/Connected.Runtime/Middleware/Context.cs b/Connected.Runtime/Middleware/Context.cs new file mode 100644 index 0000000..b979d19 --- /dev/null +++ b/Connected.Runtime/Middleware/Context.cs @@ -0,0 +1,64 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace Connected.ServiceModel +{ + internal sealed class Context : IContext, IDisposable + { + private CancellationTokenSource _cancellationTokenSource; + + public Context() + { + _cancellationTokenSource = new(); + } + + public CancellationToken CancellationToken => _cancellationTokenSource.Token; + internal IServiceScope? Scope { get; set; } + private bool IsDisposed { get; set; } + + public T? GetService() + { + if (Scope is null) + return default; + + try + { + return Scope.ServiceProvider.GetService(); + } + catch (ObjectDisposedException) + { + return default; + } + } + + public object? GetService(Type serviceType) + { + if (Scope is null) + return default; + + return Scope.ServiceProvider.GetService(serviceType); + } + private void Dispose(bool disposing) + { + if (IsDisposed) + return; + + if (disposing) + { + if (!_cancellationTokenSource.IsCancellationRequested) + _cancellationTokenSource.Cancel(false); + + Scope?.Dispose(); + + _cancellationTokenSource.Dispose(); + } + + IsDisposed = true; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + } +} diff --git a/Connected.Runtime/Middleware/ContextProvider.cs b/Connected.Runtime/Middleware/ContextProvider.cs new file mode 100644 index 0000000..9438cf2 --- /dev/null +++ b/Connected.Runtime/Middleware/ContextProvider.cs @@ -0,0 +1,17 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace Connected.ServiceModel +{ + internal class ContextProvider : IContextProvider + { + public IContext Create() + { + var scope = RuntimeStartup.Application.Services.CreateScope(); + var ctx = scope.ServiceProvider.GetService() as Context; + + ctx.Scope = scope; + + return ctx; + } + } +} diff --git a/Connected.Runtime/Middleware/ICallerContext.cs b/Connected.Runtime/Middleware/ICallerContext.cs new file mode 100644 index 0000000..66973e5 --- /dev/null +++ b/Connected.Runtime/Middleware/ICallerContext.cs @@ -0,0 +1,8 @@ +namespace Connected.ServiceModel +{ + public interface ICallerContext + { + object? Sender { get; } + string? Method { get; } + } +} diff --git a/Connected.Runtime/Middleware/ICancellationContext.cs b/Connected.Runtime/Middleware/ICancellationContext.cs new file mode 100644 index 0000000..63a656a --- /dev/null +++ b/Connected.Runtime/Middleware/ICancellationContext.cs @@ -0,0 +1,7 @@ +namespace Connected.ServiceModel +{ + public interface ICancellationContext + { + CancellationToken CancellationToken { get; } + } +} diff --git a/Connected.Runtime/Middleware/IContext.cs b/Connected.Runtime/Middleware/IContext.cs new file mode 100644 index 0000000..1351cce --- /dev/null +++ b/Connected.Runtime/Middleware/IContext.cs @@ -0,0 +1,9 @@ +namespace Connected.ServiceModel +{ + public interface IContext : IDisposable + { + CancellationToken CancellationToken { get; } + T? GetService(); + object? GetService(Type serviceType); + } +} diff --git a/Connected.Runtime/Middleware/IContextProvider.cs b/Connected.Runtime/Middleware/IContextProvider.cs new file mode 100644 index 0000000..be44e4f --- /dev/null +++ b/Connected.Runtime/Middleware/IContextProvider.cs @@ -0,0 +1,7 @@ +namespace Connected.ServiceModel +{ + public interface IContextProvider + { + IContext Create(); + } +} diff --git a/Connected.Runtime/RuntimeExtensions.cs b/Connected.Runtime/RuntimeExtensions.cs new file mode 100644 index 0000000..402b60f --- /dev/null +++ b/Connected.Runtime/RuntimeExtensions.cs @@ -0,0 +1,17 @@ +using Microsoft.AspNetCore.Http; + +namespace Connected; + +public static class RuntimeExtensions +{ + public static bool IsAjaxRequest(this HttpRequest request) + { + if (request is null) + throw new ArgumentNullException(nameof(request)); + + if (request.Headers is not null && request.Headers.ContainsKey("X-Requested-With")) + return string.Equals(request.Headers["X-Requested-With"], "XMLHttpRequest", StringComparison.OrdinalIgnoreCase); + + return false; + } +} diff --git a/Connected.Runtime/RuntimeStartup.cs b/Connected.Runtime/RuntimeStartup.cs new file mode 100644 index 0000000..8478f3c --- /dev/null +++ b/Connected.Runtime/RuntimeStartup.cs @@ -0,0 +1,29 @@ +using Connected.Annotations; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + + +namespace Connected; + +internal sealed class RuntimeStartup : Startup +{ + public static WebApplication? Application { get; private set; } + + protected override void OnConfigure(WebApplication app) + { + Application = app; + } + + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(IContextProvider), typeof(ContextProvider)); + + services.AddScoped(typeof(IContext), typeof(Context)); + services.AddScoped(typeof(ITransactionContext), typeof(TransactionContext)); + services.AddScoped(typeof(ICancellationContext), typeof(CancellationContext)); + } +} diff --git a/Connected.Runtime/ServiceModelExtensions.cs b/Connected.Runtime/ServiceModelExtensions.cs new file mode 100644 index 0000000..ed2340b --- /dev/null +++ b/Connected.Runtime/ServiceModelExtensions.cs @@ -0,0 +1,50 @@ +using Connected; + +namespace Connected.ServiceModel; + +public static class ServiceModelExtensions +{ + public static List GetImplementedArguments(this Type type) + { + /* + * Only direct implementation is used so we can eliminate multiple implementations + * and thus resolving wrong arguments when mapping request. + */ + var interfaces = type.GetInterfaces(); + var allInterfaces = new List(); + var baseInterfaces = new List(); + + foreach (var i in interfaces) + { + if (typeof(IDto)?.FullName is not string fullName) + continue; + + if (i.GetInterface(fullName) is null) + continue; + + if (i == typeof(IDto)) + continue; + + allInterfaces.Add(i); + + foreach (var baseInterface in i.GetInterfaces()) + { + if (baseInterface == typeof(IDto) || typeof(IDto)?.FullName is not string baseFullName) + continue; + + if (baseInterface.GetInterface(baseFullName) is not null) + baseInterfaces.Add(baseInterface); + } + } + + return allInterfaces.Except(baseInterfaces).ToList(); + } + + public static bool IsArgumentImplementation(this Type type) + { + if (typeof(IDto)?.FullName is not string fullName) + return false; + + return !type.IsInterface && !type.IsAbstract && type.GetInterface(fullName) is not null; + } +} diff --git a/Connected.Runtime/Services/ServiceEvents.cs b/Connected.Runtime/Services/ServiceEvents.cs new file mode 100644 index 0000000..80ea684 --- /dev/null +++ b/Connected.Runtime/Services/ServiceEvents.cs @@ -0,0 +1,11 @@ +using Connected.Notifications; + +namespace Connected.ServiceModel +{ + public static class ServiceEvents + { + public const string Inserted = nameof(IServiceNotifications.Inserted); + public const string Updated = nameof(IServiceNotifications.Updated); + public const string Deleted = nameof(IServiceNotifications.Deleted); + } +} diff --git a/Connected.Runtime/Startup.cs b/Connected.Runtime/Startup.cs new file mode 100644 index 0000000..a50b7e6 --- /dev/null +++ b/Connected.Runtime/Startup.cs @@ -0,0 +1,49 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace Connected; + +public abstract class Startup : IStartup +{ + protected IServiceProvider? Services { get; private set; } + + public void Configure(WebApplication app) + { + Services = app.Services; + OnConfigure(app); + } + + protected virtual void OnConfigure(WebApplication app) + { + } + + public void ConfigureServices(IServiceCollection services) + { + OnConfigureServices(services); + } + + protected virtual void OnConfigureServices(IServiceCollection services) + { + + } + + public async Task Initialize(Dictionary args) + { + await OnInitialize(args); + } + + protected virtual async Task OnInitialize(Dictionary args) + { + await Task.CompletedTask; + } + + public async Task Start(Dictionary args) + { + await OnStart(args); + } + + protected virtual async Task OnStart(Dictionary args) + { + await Task.CompletedTask; + } +} diff --git a/Connected.Runtime/Transactions/ITransactionClient.cs b/Connected.Runtime/Transactions/ITransactionClient.cs new file mode 100644 index 0000000..d49d589 --- /dev/null +++ b/Connected.Runtime/Transactions/ITransactionClient.cs @@ -0,0 +1,8 @@ +namespace Connected.ServiceModel.Transactions +{ + public interface ITransactionClient + { + Task Commit(); + Task Rollback(); + } +} diff --git a/Connected.Runtime/Transactions/ITransactionContext.cs b/Connected.Runtime/Transactions/ITransactionContext.cs new file mode 100644 index 0000000..b190061 --- /dev/null +++ b/Connected.Runtime/Transactions/ITransactionContext.cs @@ -0,0 +1,21 @@ +namespace Connected.ServiceModel.Transactions +{ + public enum MiddlewareTransactionState + { + Active = 1, + Committing = 2, + Reverting = 3, + Completed = 4 + } + + public interface ITransactionContext + { + event EventHandler? StateChanged; + MiddlewareTransactionState State { get; } + void Register(ITransactionClient client); + bool IsDirty { get; set; } + + Task Rollback(); + Task Commit(); + } +} diff --git a/Connected.Runtime/Transactions/TransactionContext.cs b/Connected.Runtime/Transactions/TransactionContext.cs new file mode 100644 index 0000000..7057aa8 --- /dev/null +++ b/Connected.Runtime/Transactions/TransactionContext.cs @@ -0,0 +1,62 @@ +using System.Collections.Concurrent; + +namespace Connected.ServiceModel.Transactions +{ + internal class TransactionContext : ITransactionContext + { + private ConcurrentStack _operations; + private MiddlewareTransactionState _state = MiddlewareTransactionState.Active; + + public event EventHandler? StateChanged; + + public MiddlewareTransactionState State + { + get => _state; private set + { + if (_state != value) + { + _state = value; + StateChanged?.Invoke(this, EventArgs.Empty); + } + } + } + + private ConcurrentStack Operations => _operations ??= new ConcurrentStack(); + + public bool IsDirty { get; set; } + + public void Register(ITransactionClient client) + { + if (client is null || Operations.Contains(client)) + return; + + Operations.Push(client); + } + + public async Task Commit() + { + State = MiddlewareTransactionState.Committing; + + while (!Operations.IsEmpty) + { + if (Operations.TryPop(out ITransactionClient? op)) + await op?.Commit(); + } + + State = MiddlewareTransactionState.Completed; + } + + public async Task Rollback() + { + State = MiddlewareTransactionState.Reverting; + + while (!Operations.IsEmpty) + { + if (Operations.TryPop(out ITransactionClient? op)) + await op?.Rollback(); + } + + State = MiddlewareTransactionState.Completed; + } + } +} \ No newline at end of file diff --git a/Connected.Security/Authentication/AuthenticationArgs.cs b/Connected.Security/Authentication/AuthenticationArgs.cs new file mode 100644 index 0000000..4f969ff --- /dev/null +++ b/Connected.Security/Authentication/AuthenticationArgs.cs @@ -0,0 +1,11 @@ +using Connected.ServiceModel; + +namespace Connected.Security.Authentication; + +/// +/// Base interface for authentication providers used when authenticating +/// against identity. +/// +public abstract class AuthenticationArgs : Dto +{ +} diff --git a/Connected.Security/Authentication/AuthenticationResult.cs b/Connected.Security/Authentication/AuthenticationResult.cs new file mode 100644 index 0000000..6c408db --- /dev/null +++ b/Connected.Security/Authentication/AuthenticationResult.cs @@ -0,0 +1,33 @@ +using Connected.Security.Identity; + +namespace Connected.Security.Authentication; + +internal class AuthenticationResult : IAuthenticationResult +{ + public string? Token { get; init; } + + public bool Success { get; init; } + + public AuthenticationResultReason Reason { get; init; } + + public IUser? User { get; init; } + + public static IAuthenticationResult Fail(AuthenticationResultReason reason) + { + return new AuthenticationResult + { + Success = false, + Reason = reason + }; + } + + public static IAuthenticationResult OK(IUser user, string token) + { + return new AuthenticationResult + { + Success = true, + Token = token, + User = user + }; + } +} diff --git a/Connected.Security/Authentication/AuthenticationService.cs b/Connected.Security/Authentication/AuthenticationService.cs new file mode 100644 index 0000000..ee492e2 --- /dev/null +++ b/Connected.Security/Authentication/AuthenticationService.cs @@ -0,0 +1,27 @@ +using Connected.Middleware; +using Connected.Security.Authentication.Middleware; +using Connected.Threading; + +namespace Connected.Security.Authentication; + +/// +/// The implementation of the . +/// +internal sealed class AuthenticationService : IAuthenticationService +{ + public AuthenticationService(IMiddlewareService middleware) + { + Provider = new AsyncLazy(middleware.First()); + } + + private AsyncLazy Provider { get; } + public async Task Authenticate(AuthenticationArgs args) + { + var provider = await Provider.Value; + + if (provider is null) + throw new NullReferenceException(nameof(IAuthenticationMiddleware)); + + return await provider.Authenticate(args); + } +} diff --git a/Connected.Security/Authentication/BasicAuthenticationArgs.cs b/Connected.Security/Authentication/BasicAuthenticationArgs.cs new file mode 100644 index 0000000..ca35ff6 --- /dev/null +++ b/Connected.Security/Authentication/BasicAuthenticationArgs.cs @@ -0,0 +1,23 @@ +using System.ComponentModel.DataAnnotations; + +namespace Connected.Security.Authentication +{ + /// + /// Represents a basic authentication arguments which are based on the user and password. + /// + public sealed class BasicAuthenticationArgs : AuthenticationArgs + { + /// + /// The user's identity. It can contain login name, email or authentication token. + /// + [Required] + [MaxLength(128)] + public string User { get; init; } + /// + /// The user's password. + /// + [Required] + [MaxLength(128)] + public string Password { get; init; } + } +} diff --git a/Connected.Security/Authentication/BearerAuthenticationArgs.cs b/Connected.Security/Authentication/BearerAuthenticationArgs.cs new file mode 100644 index 0000000..707f8f9 --- /dev/null +++ b/Connected.Security/Authentication/BearerAuthenticationArgs.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; + +namespace Connected.Security.Authentication +{ + /// + /// Represents the token based authentication. This is usually used by + /// devices and processes which does not have a direct identity. The identity + /// is then resolved via mapping between the provided token and a user. + /// + public sealed class BearerAuthenticationArgs : AuthenticationArgs + { + /// + /// The bearer token. + /// + [Required] + [MaxLength(128)] + public string Token { get; init; } + } +} diff --git a/Connected.Security/Authentication/IAuthenticationResult.cs b/Connected.Security/Authentication/IAuthenticationResult.cs new file mode 100644 index 0000000..1a1bb4f --- /dev/null +++ b/Connected.Security/Authentication/IAuthenticationResult.cs @@ -0,0 +1,86 @@ +using Connected.Security.Identity; + +namespace Connected.Security.Authentication; + +/// +/// Defines the reason decided +/// to allow or refuse the authentication . +/// +public enum AuthenticationResultReason +{ + /// + /// The authentication was successfully. This is the only reason + /// that is used when authentication is successful. + /// + OK = 0, + /// + /// The provided identity was not found. + /// + NotFound = 1, + /// + /// The provided identity did not have a valid password. + /// + InvalidPassword = 2, + /// + /// The provided identity is not active in the environment. + /// + Inactive = 3, + /// + /// The provided identity is locked or blocked by the environment. + /// + Locked = 4, + /// + /// The provided identity does not have a password set but an + /// requires one. + /// + NoPassword = 5, + /// + /// The provided identity's password has expired. + /// + PasswordExpired = 6, + /// + /// The token provided by the identity is invalid. + /// + InvalidToken = 7, + /// + /// The credentials provided by identity are not valid or are not supported by the environment. + /// + InvalidCredentials = 8, + /// + /// There is other issue regarding identity which cannot be resolved. + /// + Other = 99 +} +/// +/// Represents the result of the authentication process. should never +/// throw an exception during authentication process. It must always return regardless +/// wether it was successful or not. +/// +public interface IAuthenticationResult +{ + /// + /// The token which can be used to uniquely identify the identity. This token is + /// generated by the when the authentication is + /// successful and no previous token was created. + /// + /// + /// Each identity should have only one active token at the time and the new token can be invalidated + /// by the environment. Token is also valid only for a limited time. Once expired, user will need to + /// authenticate again. The primary use of this token is in the SSO systems. + /// + string? Token { get; } + /// + /// Returns true if authentication was successful, false otherwise. + /// + bool Success { get; } + /// + /// The reason authentication was successful or not. + /// + AuthenticationResultReason Reason { get; } + /// + /// The identity which can be used in the process pipeline. + /// + /// + /// For example, this value will be used by HttpRequests as a User property. + IUser? User { get; } +} diff --git a/Connected.Security/Authentication/IAuthenticationService.cs b/Connected.Security/Authentication/IAuthenticationService.cs new file mode 100644 index 0000000..6173040 --- /dev/null +++ b/Connected.Security/Authentication/IAuthenticationService.cs @@ -0,0 +1,31 @@ +namespace Connected.Security.Authentication +{ + /// + /// The service which performs authentication requests. + /// + /// + /// The service uses registered to perform + /// the actual authentication. + /// + public interface IAuthenticationService + { + /// + /// Performs authentication against identity. Identity is not necessarry a user, it can be device, process + /// or any other entity which needs protection. + /// + /// The authentication arguments explaining the identity. The following identities are supported by a + /// default authentication provider: + /// + /// + /// + /// + /// + /// + /// + /// + /// containing information about authentication wether it is successfull or not and + /// the reason why it was not successfully. If it's successfully it also provides identity. + /// + Task Authenticate(AuthenticationArgs args); + } +} diff --git a/Connected.Security/Authentication/IAuthenticationToken.cs b/Connected.Security/Authentication/IAuthenticationToken.cs new file mode 100644 index 0000000..31f9262 --- /dev/null +++ b/Connected.Security/Authentication/IAuthenticationToken.cs @@ -0,0 +1,12 @@ +using Connected.Data; + +namespace Connected.Security.Authentication; + +public interface IAuthenticationToken : IPrimaryKey +{ + string Token { get; init; } + string Identity { get; init; } + DateTime Expiration { get; init; } + Status Status { get; init; } + string Tags { get; init; } +} diff --git a/Connected.Security/Authentication/IAuthenticationTokenService.cs b/Connected.Security/Authentication/IAuthenticationTokenService.cs new file mode 100644 index 0000000..60a8022 --- /dev/null +++ b/Connected.Security/Authentication/IAuthenticationTokenService.cs @@ -0,0 +1,8 @@ +using Connected.Annotations; + +namespace Connected.Security.Authentication; + +[Service] +public interface IAuthenticationTokenService +{ +} diff --git a/Connected.Security/Authentication/Middleware/DefaultAuthenticationMiddleware.cs b/Connected.Security/Authentication/Middleware/DefaultAuthenticationMiddleware.cs new file mode 100644 index 0000000..df61046 --- /dev/null +++ b/Connected.Security/Authentication/Middleware/DefaultAuthenticationMiddleware.cs @@ -0,0 +1,117 @@ +using Connected.Annotations; +using Connected.Configuration; +using Connected.Middleware; +using Connected.Security.Cryptography; +using Connected.Security.Identity; +using Microsoft.IdentityModel.Tokens; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Claims; +using System.Text; + +namespace Connected.Security.Authentication.Middleware; + +/// +/// The default implementation of the . +/// +[Priority(0)] +internal sealed class DefaultAuthenticationMiddleware : MiddlewareComponent, IAuthenticationMiddleware +{ + public DefaultAuthenticationMiddleware(IUserService userService, IRoleService roleService, IConfigurationService configurationService, ICryptographyService cryptographyService) + { + UserService = userService; + RoleService = roleService; + ConfigurationService = configurationService; + CryptographyService = cryptographyService; + } + + private IUserService UserService { get; } + private IRoleService RoleService { get; } + public IConfigurationService ConfigurationService { get; } + public ICryptographyService CryptographyService { get; } + + public async Task Authenticate(AuthenticationArgs args) + { + if (args is BasicAuthenticationArgs basic) + return await AuthenticateBasic(basic); + else if (args is PinAuthenticationArgs pin) + return await AuthenticatePin(pin); + else if (args is BearerAuthenticationArgs bearer) + return await AuthenticateBearer(bearer); + else if (args is SsoAuthenticationArgs sso) + return await AuthenticateSso(sso); + else + throw new NotSupportedException(); + } + + private async Task AuthenticateSso(SsoAuthenticationArgs sso) + { + if (await UserService.Resolve(new UserResolveArgs { Criteria = sso.Token }) is not IUserPassport user) + return AuthenticationResult.Fail(AuthenticationResultReason.InvalidCredentials); + + return AuthenticationResult.OK(user, new JwtSecurityTokenHandler().WriteToken(CreateToken(user))); + } + + private async Task AuthenticateBearer(BearerAuthenticationArgs bearer) + { + throw new NotImplementedException(); + } + + private async Task AuthenticatePin(PinAuthenticationArgs pin) + { + if (await UserService.Resolve(new UserResolveArgs { Criteria = pin.User }) is not IUserPassport user) + return AuthenticationResult.Fail(AuthenticationResultReason.NotFound); + + if (Validate(user, false) is IAuthenticationResult result && !result.Success) + return result; + + if (!CryptographyService.Verify(pin.Pin, user.Pin)) + return AuthenticationResult.Fail(AuthenticationResultReason.InvalidCredentials); + + return AuthenticationResult.OK(user, new JwtSecurityTokenHandler().WriteToken(CreateToken(user))); + } + + private async Task AuthenticateBasic(BasicAuthenticationArgs basic) + { + if (await UserService.Resolve(new UserResolveArgs { Criteria = basic.User }) is not IUserPassport user) + return AuthenticationResult.Fail(AuthenticationResultReason.NotFound); + + if (Validate(user, true) is IAuthenticationResult result && !result.Success) + return result; + + if (!CryptographyService.Verify(basic.Password, user.Password)) + return AuthenticationResult.Fail(AuthenticationResultReason.InvalidPassword); + + return AuthenticationResult.OK(user, new JwtSecurityTokenHandler().WriteToken(CreateToken(user))); + } + + private static IAuthenticationResult? Validate(IUserPassport user, bool validatePassword) + { + if (user.Status == UserStatus.Inactive) + return AuthenticationResult.Fail(AuthenticationResultReason.Inactive); + + if (user.Status == UserStatus.Locked) + return AuthenticationResult.Fail(AuthenticationResultReason.Locked); + + if (validatePassword) + { + if (user.Password is null || !user.Password.Any()) + return AuthenticationResult.Fail(AuthenticationResultReason.NoPassword); + + if (user.PasswordExpiration != DateTime.MinValue && user.PasswordExpiration < DateTime.UtcNow) + return AuthenticationResult.Fail(AuthenticationResultReason.PasswordExpired); + } + + return null; + } + + private JwtSecurityToken CreateToken(IUserPassport user) + { + var config = ConfigurationService.Authentication.JwToken; + var key = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(config.Key)); + var cred = new SigningCredentials(key, SecurityAlgorithms.HmacSha256); + var claims = new[] { new Claim(ClaimTypes.NameIdentifier, user.AuthenticationToken.ToString()) }; + + return new JwtSecurityToken(issuer: config.Issuer, audience: config.Audience, claims: claims, + expires: DateTime.Now.AddDays(Math.Max(1, config.Duration)), signingCredentials: cred); + } +} diff --git a/Connected.Security/Authentication/Middleware/IAuthenticationMiddleware.cs b/Connected.Security/Authentication/Middleware/IAuthenticationMiddleware.cs new file mode 100644 index 0000000..b0ecfc0 --- /dev/null +++ b/Connected.Security/Authentication/Middleware/IAuthenticationMiddleware.cs @@ -0,0 +1,33 @@ +namespace Connected.Security.Authentication.Middleware; + +/// +/// Represents an authentication middleware which is called when authenticating identity. +/// +/// +/// The environment should have only one registered which +/// uniquely authenticates all requests. The environment have a built in authentication middleware +/// which authenticates identities against standard entities. If you need to authenticate identities +/// against external data sources you need to implement your own and +/// register it in the configuration. +/// +public interface IAuthenticationMiddleware : IMiddleware +{ + /// + /// Performs authentication against identity. Identity is not necessarry a user, it can be device, process + /// or any other entity which needs protection. + /// + /// The authentication arguments explaining the identity. The following identities are supported by a + /// default authentication middleware: + /// + /// + /// + /// + /// + /// + /// + /// + /// containing information about authentication wether it is successfull or not and + /// the reason why it was not successfully. If it's successfully it also provides identity. + /// + Task Authenticate(AuthenticationArgs args); +} diff --git a/Connected.Security/Authentication/PinAuthenticationArgs.cs b/Connected.Security/Authentication/PinAuthenticationArgs.cs new file mode 100644 index 0000000..c919af6 --- /dev/null +++ b/Connected.Security/Authentication/PinAuthenticationArgs.cs @@ -0,0 +1,30 @@ +using System.ComponentModel.DataAnnotations; + +namespace Connected.Security.Authentication +{ + /// + /// Represents the authentication args when authenticating users by their Pin + /// code. + /// + /// + /// This is slightly different from a authorization + /// used in cases where users have duals security needs, the password for the direct access to + /// the environment, for example by using login screen on the front end and the pin code used + /// by security cards used by alternative authentication protocols. + /// + public sealed class PinAuthenticationArgs : AuthenticationArgs + { + /// + /// The user's identity. It can contain login name, email or authentication token. + /// + [Required] + [MaxLength(128)] + public string User { get; init; } + /// + /// The user's Pin code. + /// + [Required] + [MaxLength(32)] + public string Pin { get; init; } + } +} diff --git a/Connected.Security/Authentication/SsoAuthenticationArgs.cs b/Connected.Security/Authentication/SsoAuthenticationArgs.cs new file mode 100644 index 0000000..e49ce7f --- /dev/null +++ b/Connected.Security/Authentication/SsoAuthenticationArgs.cs @@ -0,0 +1,24 @@ +using System.ComponentModel.DataAnnotations; + +namespace Connected.Security.Authentication +{ + /// + /// Represents the Single Sign On arguments. + /// + /// + /// Once users are successfully authenticated against the identity the + /// authentication token is returned by the . + /// This token can be later used by any client to perform authentication and get the user's identity. + /// + public sealed class SsoAuthenticationArgs : AuthenticationArgs + { + /// + /// The security token issued by the . The system + /// can match the user's identity based on this value. Tat means the token is unique + /// across the entire environment. + /// + [Required] + [MaxLength(128)] + public string Token { get; init; } + } +} diff --git a/Connected.Security/Authorization/AuthorizationArgs.cs b/Connected.Security/Authorization/AuthorizationArgs.cs new file mode 100644 index 0000000..f8cc45a --- /dev/null +++ b/Connected.Security/Authorization/AuthorizationArgs.cs @@ -0,0 +1,13 @@ +namespace Connected.Security.Authorization +{ + public class AuthorizationArgs : EventArgs + { + public IAuthorizationSchema? Schema { get; init; } + public int? User { get; init; } + public object? PrimaryKey { get; init; } + public string Claim { get; init; } + public string? Entity { get; init; } + public string? Component { get; init; } + public string? Method { get; init; } + } +} diff --git a/Connected.Security/Authorization/AuthorizationResult.cs b/Connected.Security/Authorization/AuthorizationResult.cs new file mode 100644 index 0000000..80517b0 --- /dev/null +++ b/Connected.Security/Authorization/AuthorizationResult.cs @@ -0,0 +1,14 @@ +namespace Connected.Security.Authorization +{ + public sealed class AuthorizationResult : IAuthorizationResult + { + public bool Success { get; init; } + + public AuthorizationResultReason Reason { get; init; } + + public int PermissionCount { get; init; } + + public static AuthorizationResult OK() => new() { Success = true, Reason = AuthorizationResultReason.OK }; + public static AuthorizationResult Fail(AuthorizationResultReason reason) => new() { Success = false, Reason = reason }; + } +} diff --git a/Connected.Security/Authorization/AuthorizationSchema.cs b/Connected.Security/Authorization/AuthorizationSchema.cs new file mode 100644 index 0000000..7ccb8f7 --- /dev/null +++ b/Connected.Security/Authorization/AuthorizationSchema.cs @@ -0,0 +1,8 @@ +namespace Connected.Security.Authorization +{ + public sealed class AuthorizationSchema : IAuthorizationSchema + { + public EmptyPolicyBehavior EmptyPolicy { get; set; } = EmptyPolicyBehavior.Deny; + public AuthorizationStrategy Strategy { get; set; } = AuthorizationStrategy.Pessimistic; + } +} diff --git a/Connected.Security/Authorization/AuthorizationService.cs b/Connected.Security/Authorization/AuthorizationService.cs new file mode 100644 index 0000000..814be75 --- /dev/null +++ b/Connected.Security/Authorization/AuthorizationService.cs @@ -0,0 +1,143 @@ +using Connected.Middleware; +using Connected.Security.Authorization.Middleware; +using Connected.Security.Permissions; +using Connected.Threading; +using System.Collections.Immutable; + +namespace Connected.Security.Authorization; + +internal class AuthorizationService : IAuthorizationService +{ + public AuthorizationService(IMiddlewareService middleware, IPermissionService permissions) + { + Middleware = new AsyncLazy>(middleware.Query()); + Permissions = permissions; + } + + private IPermissionService Permissions { get; } + + private AsyncLazy> Middleware { get; } + + public async Task Authorize(AuthorizationArgs args) + { + if (args is null) + throw new ArgumentException(null, nameof(args)); + /* + * Claim is required otherwise we don't know what to authorize. + */ + if (string.IsNullOrWhiteSpace(args.Claim)) + return AuthorizationResult.Fail(AuthorizationResultReason.NoClaim); + + /* + * Query permissions for the specified arguments. + */ + var permissions = await Permissions.Query(new PermissionSearchArgs + { + Entity = args.Entity, + PrimaryKey = args.PrimaryKey?.ToString(), + Claim = args.Claim + }); + /* + * We'll use the transition state between calls for better performance. + */ + var state = new Dictionary(); + /* + * The first state is to find out if we need to perform the actual authorization. + * If at least one of the providers returns success from this stage, the authorization + * will immediately return. For example, if the user has Full Control role there is no + * need to perform full authorization since the user has access to any resource. + */ + foreach (var i in await Middleware.Value) + { + if (await i.PreAuthorize(args, state) == AuthorizationProviderResult.Success) + return AuthorizationResult.OK(); + } + /* + * If there are not permissions found we'll look into the arguments configuration. Optimistic scenarios + * will succeed authorization, pessimistic not. + */ + if (permissions is null || !permissions.Any()) + { + return args.Schema.EmptyPolicy switch + { + EmptyPolicyBehavior.Deny => AuthorizationResult.Fail(AuthorizationResultReason.Empty), + EmptyPolicyBehavior.Alow => AuthorizationResult.OK(), + _ => throw new NotSupportedException(), + }; + } + /* + * We have permissions, let's do the authorization + */ + var denyFound = false; + var allowFound = false; + /* + * We must authorize permission by permission. + */ + foreach (var permission in permissions) + { + /* + * We'll ask every provider to tell us what he thinks about each + * permission and then, based on its result, we're gonna decide + * what to do. + */ + foreach (var provider in await Middleware.Value) + { + /* + * Perform the authorization by the provider. + */ + var result = await provider.Authorize(permission, args, state); + /* + * Mark its result for the decision. + */ + switch (result) + { + case AuthorizationProviderResult.Success: + allowFound = true; + break; + case AuthorizationProviderResult.Fail: + denyFound = true; + break; + } + /* + * The first scenarios is that authorization succedded. If the strategy is also optimistic + * it means we were looking for the first successfull authorization. If found, the authorization + * completes successfully. + * + * The second scenarios is oposite. The authorization failed and the strategy is pessimistic which means + * we were looking for the first unsuccessful authorization. If found, the authorization completes unsuccessfully. + */ + if (result == AuthorizationProviderResult.Success && args.Schema.Strategy == AuthorizationStrategy.Optimistic) + return AuthorizationResult.OK(); + else if (result == AuthorizationProviderResult.Fail && args.Schema.Strategy == AuthorizationStrategy.Pessimistic) + return AuthorizationResult.Fail(AuthorizationResultReason.Other); + } + } + /* + * No authorization check directly decided the result so the last stage is to determine the overall result. + */ + switch (args.Schema.Strategy) + { + case AuthorizationStrategy.Pessimistic: + /* + * Pessimistic scenario where not unsuccessful authorization occured we need at least one allow. If there are + * no allow permissions the authorization fail. + */ + if (allowFound) + return AuthorizationResult.OK(); + else + return AuthorizationResult.Fail(AuthorizationResultReason.NoAllowFound); + case AuthorizationStrategy.Optimistic: + /* + * Optimistic scenario, where no successful authorizations occured we are looking for the explicit + * deny authorizations. If found, the authorization fails. + */ + if (denyFound) + return AuthorizationResult.Fail(AuthorizationResultReason.DenyFound); + else + return AuthorizationResult.OK(); + default: + throw new NotSupportedException(); + } + + } +} diff --git a/Connected.Security/Authorization/IAuthorizationContext.cs b/Connected.Security/Authorization/IAuthorizationContext.cs new file mode 100644 index 0000000..65a0c9b --- /dev/null +++ b/Connected.Security/Authorization/IAuthorizationContext.cs @@ -0,0 +1,26 @@ +using Connected.ServiceModel; + +namespace Connected.Security.Authorization; + +public enum AuthorizationContextState +{ + Pending = 0, + Authorizing = 1, + Granted = 2, + Revoked = 3 +} + +public interface IAuthorizationContext +{ + AuthorizationContextState State { get; } + + Task Authorize(AuthorizationArgs args); + + Task Authorize(ICallerContext context, TArgs args) + where TArgs : IDto; + + Task Authorize(ICallerContext context, TArgs args, TComponent component) + where TArgs : IDto; + + void Revoke(); +} diff --git a/Connected.Security/Authorization/IAuthorizationResult.cs b/Connected.Security/Authorization/IAuthorizationResult.cs new file mode 100644 index 0000000..c8be507 --- /dev/null +++ b/Connected.Security/Authorization/IAuthorizationResult.cs @@ -0,0 +1,19 @@ +namespace Connected.Security.Authorization +{ + public enum AuthorizationResultReason + { + OK = 0, + Empty = 1, + NoPrimaryKey = 2, + NoAllowFound = 3, + DenyFound = 4, + NoClaim = 5, + Other = 99 + } + + public interface IAuthorizationResult + { + bool Success { get; } + AuthorizationResultReason Reason { get; } + } +} diff --git a/Connected.Security/Authorization/IAuthorizationSchema.cs b/Connected.Security/Authorization/IAuthorizationSchema.cs new file mode 100644 index 0000000..1fe7830 --- /dev/null +++ b/Connected.Security/Authorization/IAuthorizationSchema.cs @@ -0,0 +1,20 @@ +namespace Connected.Security.Authorization +{ + public enum EmptyPolicyBehavior + { + Deny = 1, + Alow = 2 + } + + public enum AuthorizationStrategy + { + Pessimistic = 1, + Optimistic = 2 + } + + public interface IAuthorizationSchema + { + EmptyPolicyBehavior EmptyPolicy { get; set; } + AuthorizationStrategy Strategy { get; set; } + } +} diff --git a/Connected.Security/Authorization/IAuthorizationService.cs b/Connected.Security/Authorization/IAuthorizationService.cs new file mode 100644 index 0000000..2f7fa5e --- /dev/null +++ b/Connected.Security/Authorization/IAuthorizationService.cs @@ -0,0 +1,7 @@ +namespace Connected.Security.Authorization +{ + public interface IAuthorizationService + { + Task Authorize(AuthorizationArgs args); + } +} diff --git a/Connected.Security/Authorization/Middleware/IAuthorizationMiddleware.cs b/Connected.Security/Authorization/Middleware/IAuthorizationMiddleware.cs new file mode 100644 index 0000000..bb54e8c --- /dev/null +++ b/Connected.Security/Authorization/Middleware/IAuthorizationMiddleware.cs @@ -0,0 +1,18 @@ +using System.Collections.Immutable; +using Connected.Security.Permissions; + +namespace Connected.Security.Authorization.Middleware; + +public enum AuthorizationProviderResult +{ + Success = 1, + Fail = 2, + NotHandled = 3 +} +public interface IAuthorizationMiddleware : IMiddleware +{ + string Id { get; } + Task PreAuthorize(AuthorizationArgs args, Dictionary state); + Task Authorize(IPermission permission, AuthorizationArgs args, Dictionary state); + Task> QueryDescriptors(); +} diff --git a/Connected.Security/Authorization/Middleware/RoleAuthorizationMiddleware.cs b/Connected.Security/Authorization/Middleware/RoleAuthorizationMiddleware.cs new file mode 100644 index 0000000..ced594f --- /dev/null +++ b/Connected.Security/Authorization/Middleware/RoleAuthorizationMiddleware.cs @@ -0,0 +1,119 @@ +using Connected.Interop; +using Connected.Middleware; +using Connected.Security.Identity; +using Connected.Security.Membership; +using Connected.Security.Permissions; +using Connected.ServiceModel; +using Microsoft.Extensions.Logging; +using System.Collections.Immutable; + +namespace Connected.Security.Authorization.Middleware; + +internal class RoleAuthorizationMiddleware : MiddlewareComponent, IAuthorizationMiddleware +{ + public RoleAuthorizationMiddleware(IRoleService roleService, IUserService userService, ILogger logger, IMembershipService membershipService) + { + RoleService = roleService; + UserService = userService; + Logger = logger; + MembershipService = membershipService; + } + + public string Id => "Roles"; + private IRoleService RoleService { get; } + private IUserService UserService { get; } + private ILogger Logger { get; } + private IMembershipService MembershipService { get; } + + public Task Authorize(IPermission permission, AuthorizationArgs e, Dictionary state) + { + if (state["roles"] is not List roles) + return Task.FromResult(AuthorizationProviderResult.NotHandled); + + if (!TypeConversion.TryConvert(permission.Evidence, out int evidence)) + return Task.FromResult(AuthorizationProviderResult.NotHandled); + + if (roles.Contains(evidence)) + { + return permission.Value switch + { + PermissionValue.NotSet => Task.FromResult(AuthorizationProviderResult.NotHandled), + PermissionValue.Allow => Task.FromResult(AuthorizationProviderResult.Success), + PermissionValue.Deny => Task.FromResult(AuthorizationProviderResult.Fail), + _ => throw new NotSupportedException(), + }; + } + + return Task.FromResult(AuthorizationProviderResult.NotHandled); + } + + public async Task PreAuthorize(AuthorizationArgs e, Dictionary state) + { + var roles = await ResolveImplicitRoles(e); + + state.Add("roles", roles); + + if (e.User > 0) + { + var list = await MembershipService.Query(new MembershipQueryArgs { User = (int)e.User }); + + if (list is not null && list.Any()) + roles.AddRange(list.Select(f => f.Role)); + } + + if (await RoleService.Select(new NameArgs { Name = Roles.FullControl }) is IRole fullControl && roles.Contains(fullControl.Id)) + return AuthorizationProviderResult.Success; + + return AuthorizationProviderResult.NotHandled; + } + + public async Task> QueryDescriptors() + { + if (await RoleService.Query() is not ImmutableList roles) + return ImmutableList.Empty; + + var result = new List(); + + foreach (var role in roles) + { + result.Add(new PermissionSchemaDescriptor + { + Title = role.Name, + Id = role.Id.ToString() + }); + } + + return result.ToImmutableList(); + } + + private async Task> ResolveImplicitRoles(AuthorizationArgs e) + { + var result = new List(); + + if (await RoleService.Select(Roles.Everyone) is IRole everyone) + result.Add(everyone.Id); + + if (e.User == 0) + { + if (await RoleService.Select(Roles.Anonymous) is IRole anonymous) + result.Add(anonymous.Id); + } + else + { + if (e.User is null || await UserService.Select(e.User) is null) + { + if (await RoleService.Select(Roles.Anonymous) is IRole anonymous) + result.Add(anonymous.Id); + + Logger.LogWarning("Authenticated user not found. Request will be treated as anonymous {user}.", e.User); + + return result; + } + + if (await RoleService.Select(Roles.Authenticated) is IRole authenticated) + result.Add(authenticated.Id); + } + + return result; + } +} \ No newline at end of file diff --git a/Connected.Security/Authorization/Middleware/UserAuthorizationMiddleware.cs b/Connected.Security/Authorization/Middleware/UserAuthorizationMiddleware.cs new file mode 100644 index 0000000..2546f88 --- /dev/null +++ b/Connected.Security/Authorization/Middleware/UserAuthorizationMiddleware.cs @@ -0,0 +1,63 @@ +using Connected.Middleware; +using Connected.Security.Identity; +using Connected.Security.Permissions; +using System.Collections.Immutable; + +namespace Connected.Security.Authorization.Middleware; + +internal class UserAuthorizationMiddleware : MiddlewareComponent, IAuthorizationMiddleware +{ + public UserAuthorizationMiddleware(IUserService userService) + { + UserService = userService; + } + + public string Id => "Users"; + + public IUserService UserService { get; } + + public Task Authorize(IPermission permission, AuthorizationArgs args, Dictionary state) + { + if (!string.Equals(args.User.ToString(), permission.Evidence, StringComparison.OrdinalIgnoreCase)) + return Task.FromResult(AuthorizationProviderResult.NotHandled); + + switch (permission.Value) + { + case PermissionValue.NotSet: + return Task.FromResult(AuthorizationProviderResult.NotHandled); + case PermissionValue.Allow: + return Task.FromResult(AuthorizationProviderResult.Success); + case PermissionValue.Deny: + return Task.FromResult(AuthorizationProviderResult.Fail); + default: + throw new NotSupportedException(); + } + } + + public Task PreAuthorize(AuthorizationArgs args, Dictionary state) + { + return Task.FromResult(AuthorizationProviderResult.NotHandled); + } + + public async Task> QueryDescriptors() + { + var users = await UserService.Query(); + var r = new List(); + + foreach (var i in users) + { + r.Add(new PermissionSchemaDescriptor + { + Id = i.Id.ToString(), + Title = i.DisplayName(), + Description = i.Email + /* + * TODO: handle avatar + */ + + }); + } + + return r.ToImmutableList(); + } +} diff --git a/Connected.Security/Connected.Security.csproj b/Connected.Security/Connected.Security.csproj new file mode 100644 index 0000000..1ec86a6 --- /dev/null +++ b/Connected.Security/Connected.Security.csproj @@ -0,0 +1,34 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Security/Cryptography/CryptographyService.cs b/Connected.Security/Cryptography/CryptographyService.cs new file mode 100644 index 0000000..292039d --- /dev/null +++ b/Connected.Security/Cryptography/CryptographyService.cs @@ -0,0 +1,47 @@ +using System.Security.Cryptography; +using System.Text; + +namespace Connected.Security.Cryptography +{ + internal class CryptographyService : ICryptographyService + { + public byte[]? Hash(string value) + { + if (string.IsNullOrEmpty(value)) + return null; + + using var md = MD5.Create(); + + return GetHash(md, value); + } + + private static byte[] GetHash(MD5 hash, string value) + { + return hash.ComputeHash(Encoding.UTF8.GetBytes(value)); + } + + public bool Verify(string value, byte[] existing) + { + if (value is null && existing is null) + return false; + + if (!value.Any() && !existing.Any()) + return false; + + using var md = MD5.Create(); + + var hash = GetHash(md, value); + + if (value.Length != hash.Length) + return false; + + for (var i = 0; i < value.Length; i++) + { + if (value[i] != hash[i]) + return false; + } + + return true; + } + } +} diff --git a/Connected.Security/Cryptography/ICryptographyService.cs b/Connected.Security/Cryptography/ICryptographyService.cs new file mode 100644 index 0000000..c0a95fd --- /dev/null +++ b/Connected.Security/Cryptography/ICryptographyService.cs @@ -0,0 +1,8 @@ +namespace Connected.Security.Cryptography +{ + public interface ICryptographyService + { + byte[]? Hash(string value); + bool Verify(string value, byte[] existing); + } +} diff --git a/Connected.Security/Identity/IIdentityService.cs b/Connected.Security/Identity/IIdentityService.cs new file mode 100644 index 0000000..5717740 --- /dev/null +++ b/Connected.Security/Identity/IIdentityService.cs @@ -0,0 +1,8 @@ +namespace Connected.Security.Identity +{ + public interface IIdentityService + { + IUser? CurrentUser { get; } + bool IsAuthenticated { get; } + } +} diff --git a/Connected.Security/Identity/IRole.cs b/Connected.Security/Identity/IRole.cs new file mode 100644 index 0000000..dae9ef6 --- /dev/null +++ b/Connected.Security/Identity/IRole.cs @@ -0,0 +1,8 @@ +using Connected.Data; + +namespace Connected.Security.Identity; + +public interface IRole : IPrimaryKey +{ + string Name { get; init; } +} diff --git a/Connected.Security/Identity/IRoleService.cs b/Connected.Security/Identity/IRoleService.cs new file mode 100644 index 0000000..2f48e87 --- /dev/null +++ b/Connected.Security/Identity/IRoleService.cs @@ -0,0 +1,31 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.ServiceModel; + +namespace Connected.Security.Identity; + +[Service] +[ServiceUrl(SecurityRoutes.Roles)] +public interface IRoleService +{ + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(); + + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(PrimaryKeyListArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(NameArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task Insert(RoleArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Patch)] + Task Update(RoleUpdateArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Security/Identity/IUser.cs b/Connected.Security/Identity/IUser.cs new file mode 100644 index 0000000..f0fc0d9 --- /dev/null +++ b/Connected.Security/Identity/IUser.cs @@ -0,0 +1,65 @@ +using Connected.Data; + +namespace Connected.Security.Identity; + +/// +/// Specifies the user's status. +/// +public enum UserStatus +{ + /// + /// User is not active and cannot log into the environment. + /// + Inactive = 0, + /// + /// User is a valid user and can log into the environment and + /// use its identity. + /// + Active = 1, + /// + /// User is locked out and cannot log into the environment. + /// + Locked = 2 +} +/// +/// Represents user entity. User is basic artifact of the identity +/// infrastructure. +/// +/// +/// User tipically maps to a person or a physical end user of the +/// environment. +/// +public interface IUser : IPrimaryKey +{ + /// + /// The user's first name. + /// + string? FirstName { get; init; } + /// + /// The user's last name + /// + string? LastName { get; init; } + /// + /// The login name which can be used when authenticating the user. + /// + string? LoginName { get; init; } + /// + /// The email associated with the user. The email should be unique + /// across the environment. + /// + string? Email { get; init; } + /// + /// Timezone used when representing Date and Time values. UTC timezone + /// is used by default, if this property is not set. + /// + string? TimeZone { get; init; } + /// + /// Language used by the user as defined in the + /// service. If not set, default, environment wide language is used. + /// + int Language { get; init; } + /// + /// The status of the user indicating wether user can log into the environment or not. + /// + UserStatus Status { get; init; } +} diff --git a/Connected.Security/Identity/IUserPassport.cs b/Connected.Security/Identity/IUserPassport.cs new file mode 100644 index 0000000..65be476 --- /dev/null +++ b/Connected.Security/Identity/IUserPassport.cs @@ -0,0 +1,11 @@ +namespace Connected.Security.Identity +{ + public interface IUserPassport : IUser + { + byte[] Password { get; } + byte[] Pin { get; } + + DateTime PasswordExpiration { get; } + Guid AuthenticationToken { get; } + } +} diff --git a/Connected.Security/Identity/IUserService.cs b/Connected.Security/Identity/IUserService.cs new file mode 100644 index 0000000..08a6e19 --- /dev/null +++ b/Connected.Security/Identity/IUserService.cs @@ -0,0 +1,41 @@ +using Connected.Annotations; +using Connected.ServiceModel; +using System.Collections.Immutable; + +namespace Connected.Security.Identity; + +[Service] +[ServiceUrl(SecurityRoutes.Users)] +public interface IUserService +{ + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(); + + + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(PrimaryKeyListArgs e); + + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Resolve(UserResolveArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task Insert(UserInsertArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Patch)] + Task Update(UserUpdateArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task UpdatePassword(UserPasswordArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Security/Identity/RoleArgs.cs b/Connected.Security/Identity/RoleArgs.cs new file mode 100644 index 0000000..b164812 --- /dev/null +++ b/Connected.Security/Identity/RoleArgs.cs @@ -0,0 +1,18 @@ +using System.ComponentModel.DataAnnotations; +using Connected; +using Connected.Annotations; + +namespace Connected.Security.Identity; + +public class RoleArgs : IDto +{ + [Required] + [MaxLength(128)] + public string Name { get; init; } +} + +public class RoleUpdateArgs : RoleArgs +{ + [MinValue(0)] + public int Id { get; init; } +} diff --git a/Connected.Security/Identity/Roles.cs b/Connected.Security/Identity/Roles.cs new file mode 100644 index 0000000..d51fbcb --- /dev/null +++ b/Connected.Security/Identity/Roles.cs @@ -0,0 +1,10 @@ +namespace Connected.Security.Identity +{ + public static class Roles + { + public const string FullControl = "Full Control"; + public const string Authenticated = "Authenticated"; + public const string Anonymous = "Anonymous"; + public const string Everyone = "Everyone"; + } +} diff --git a/Connected.Security/Identity/UserArgs.cs b/Connected.Security/Identity/UserArgs.cs new file mode 100644 index 0000000..fed6cd0 --- /dev/null +++ b/Connected.Security/Identity/UserArgs.cs @@ -0,0 +1,29 @@ +using Connected.ServiceModel; + +namespace Connected.Security.Identity; + +public class UserInsertArgs : Dto +{ + public string? FirstName { get; init; } + public string? LastName { get; init; } + public string? LoginName { get; init; } + public string? Email { get; init; } + public string? TimeZone { get; init; } + public int Language { get; init; } +} + +public sealed class UserUpdateArgs : UserInsertArgs +{ + public int Id { get; init; } +} + +public sealed class UserResolveArgs : Dto +{ + public string? Criteria { get; init; } +} + +public sealed class UserPasswordArgs : PrimaryKeyArgs +{ + public string? ExistingPassword { get; init; } + public string? NewPassword { get; init; } +} diff --git a/Connected.Security/Identity/UserIdentity.cs b/Connected.Security/Identity/UserIdentity.cs new file mode 100644 index 0000000..291852f --- /dev/null +++ b/Connected.Security/Identity/UserIdentity.cs @@ -0,0 +1,50 @@ +using System.Security.Claims; + +namespace Connected.Security.Identity +{ + public class UserIdentity : ClaimsIdentity + { + private bool _isAuthenticated = true; + private List _claims = null; + + public UserIdentity(IUser user) : this(user, null) + { + } + + public UserIdentity(IUser user, string jwToken) + { + User = user; + Token = jwToken; + //Name = user.AuthenticationToken; + } + + public override string AuthenticationType => "Tom PIT"; + public override bool IsAuthenticated { get { return _isAuthenticated; } } + public override string Name { get; } + public string Token { get; } + + public IUser User { get; } + + public static UserIdentity NotAuthenticated() + { + return new UserIdentity(null, null) + { + _isAuthenticated = false + }; + } + + public override IEnumerable Claims => _claims ??= CreateClaims(); + + private List CreateClaims() + { + //TODO: resolve claims + //using var ctx = Context.Create(); + //var svc = ctx.GetService(); + + //var isAdmin = User is not null && svc.IsInRole(User.Id, Role.FullControl); + + + return new List(); + } + } +} diff --git a/Connected.Security/Membership/IMembership.cs b/Connected.Security/Membership/IMembership.cs new file mode 100644 index 0000000..98bdef0 --- /dev/null +++ b/Connected.Security/Membership/IMembership.cs @@ -0,0 +1,21 @@ +using Connected.Data; + +namespace Connected.Security.Membership; + +/// +/// Represents a user's membership. +/// +/// +/// can belong to one or more . This relations is defined via . +/// +public interface IMembership : IPrimaryKey +{ + /// + /// The id of the . + /// + int User { get; } + /// + /// The id of the . + /// + int Role { get; } +} diff --git a/Connected.Security/Membership/IMembershipService.cs b/Connected.Security/Membership/IMembershipService.cs new file mode 100644 index 0000000..4115f34 --- /dev/null +++ b/Connected.Security/Membership/IMembershipService.cs @@ -0,0 +1,25 @@ +using System.Collections.Immutable; +using Connected.Annotations; +using Connected.ServiceModel; + +namespace Connected.Security.Membership; + +[Service] +[ServiceUrl(SecurityRoutes.Membership)] +public interface IMembershipService +{ + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(); + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task?> Query(MembershipQueryArgs args); + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task Insert(MembershipArgs args); + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Security/Membership/MembershipArgs.cs b/Connected.Security/Membership/MembershipArgs.cs new file mode 100644 index 0000000..f2690e9 --- /dev/null +++ b/Connected.Security/Membership/MembershipArgs.cs @@ -0,0 +1,15 @@ +using Connected.ServiceModel; + +namespace Connected.Security.Membership; + +public class MembershipArgs : Dto +{ + public int User { get; init; } + public int Role { get; init; } +} + +public sealed class MembershipQueryArgs : Dto +{ + public int User { get; init; } + public int Role { get; init; } +} diff --git a/Connected.Security/Permissions/IPermission.cs b/Connected.Security/Permissions/IPermission.cs new file mode 100644 index 0000000..29202ba --- /dev/null +++ b/Connected.Security/Permissions/IPermission.cs @@ -0,0 +1,78 @@ +using Connected.Data; + +namespace Connected.Security.Permissions; + +/// +/// Specifies the state of each permission entry. +/// +public enum PermissionValue +{ + /// + /// Permission is not set on the entry. This is a default value + /// of each permission entry. + /// + NotSet = 0, + /// + /// Evidence does have a claim for the specified resource. + /// + Allow = 1, + /// + /// Evidence does not have a claim for the specified resource. + /// + Deny = 2 +} +/// +/// Represents the permission entry for the specific resource. +/// +/// +/// Environment's assets are protected by . The implementation of each +/// policy is based on the which usually provides the Action, +/// which can be set to assets. The most common assets are methods. Assets or +/// define the which along with Action represents the basics of the permission. +/// The implementation contains the logic what claims are needed to perform each action. Additionally, +/// policy tipically provides a set of claims on which permissions can be set. Permissions are based on descriptors, which can be +/// User, Role or any other registered implementation of the interface. Descriptor provides a set of +/// schemas, usually users and roles and that concludes the permission's component model. +/// +public interface IPermission : IPrimaryKey +{ + /// + /// The id of the evidence to which permission is bound to. This is + /// typically provided by . + /// + string Evidence { get; } + /// + /// The type of the evidence to which permission is bound to. This is + /// typically provided by . + /// + string Schema { get; } + /// + /// The claim to which permission is bound to. This is typically + /// provided by . + /// + string Claim { get; } + /// + /// The primary key of the entity. Can be null if permission is not record based. + /// + string? PrimaryKey { get; } + /// + /// The entity to which permission is bound to. Can be null if permission is + /// environment wide and not bound to a specific entity. + /// + string? Entity { get; } + /// + /// The actual value of the permission. + /// + PermissionValue Value { get; } + /// + /// The component to which permission is bound to. This is important for advanced + /// permission models, for example where admins require the specific permission to be + /// set on a specific service method but the policy is shared between many different + /// services. + /// + string? Component { get; } + /// + /// The component's method for advanced permission models. + /// + string? Method { get; } +} diff --git a/Connected.Security/Permissions/IPermissionSchemaDescriptor.cs b/Connected.Security/Permissions/IPermissionSchemaDescriptor.cs new file mode 100644 index 0000000..1ecbe16 --- /dev/null +++ b/Connected.Security/Permissions/IPermissionSchemaDescriptor.cs @@ -0,0 +1,10 @@ +namespace Connected.Security.Permissions +{ + public interface IPermissionSchemaDescriptor + { + string Id { get; } + string Title { get; } + string Avatar { get; } + string Description { get; } + } +} diff --git a/Connected.Security/Permissions/IPermissionService.cs b/Connected.Security/Permissions/IPermissionService.cs new file mode 100644 index 0000000..fc8f9d6 --- /dev/null +++ b/Connected.Security/Permissions/IPermissionService.cs @@ -0,0 +1,33 @@ +using Connected.Annotations; +using Connected.ServiceModel; +using System.Collections.Immutable; + +namespace Connected.Security.Permissions; + +[Service] +[ServiceUrl(SecurityRoutes.Permissions)] +public interface IPermissionService +{ + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(); + + + [ServiceMethod(ServiceMethodVerbs.Get)] + Task?> Query(PermissionSearchArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Get | ServiceMethodVerbs.Post)] + Task Select(PrimaryKeyArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post)] + Task Insert(PermissionArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Patch)] + Task Update(PermissionUpdateArgs args); + + + [ServiceMethod(ServiceMethodVerbs.Post | ServiceMethodVerbs.Delete)] + Task Delete(PrimaryKeyArgs args); +} diff --git a/Connected.Security/Permissions/PermissionArgs.cs b/Connected.Security/Permissions/PermissionArgs.cs new file mode 100644 index 0000000..3c20fd1 --- /dev/null +++ b/Connected.Security/Permissions/PermissionArgs.cs @@ -0,0 +1,28 @@ +using Connected.ServiceModel; + +namespace Connected.Security.Permissions; + +public class PermissionArgs : Dto +{ + public string Evidence { get; init; } + public string Schema { get; init; } + public string Claim { get; init; } + public string? Descriptor { get; init; } + public string? PrimaryKey { get; init; } + public string? Entity { get; init; } + public PermissionValue Value { get; init; } + public string? Component { get; init; } + public string? Method { get; init; } +} + +public class PermissionUpdateArgs : PrimaryKeyArgs +{ + public PermissionValue Value { get; init; } +} + +public class PermissionSearchArgs : Dto +{ + public string? Entity { get; init; } + public string? Claim { get; init; } + public string? PrimaryKey { get; init; } +} diff --git a/Connected.Security/Permissions/PermissionSchemaDescriptor.cs b/Connected.Security/Permissions/PermissionSchemaDescriptor.cs new file mode 100644 index 0000000..d662928 --- /dev/null +++ b/Connected.Security/Permissions/PermissionSchemaDescriptor.cs @@ -0,0 +1,13 @@ +namespace Connected.Security.Permissions +{ + internal class PermissionSchemaDescriptor : IPermissionSchemaDescriptor + { + public string Id { get; init; } + + public string Title { get; init; } + + public string Avatar { get; init; } + + public string Description { get; init; } + } +} diff --git a/Connected.Security/SR.Designer.cs b/Connected.Security/SR.Designer.cs new file mode 100644 index 0000000..4c62e88 --- /dev/null +++ b/Connected.Security/SR.Designer.cs @@ -0,0 +1,99 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Security { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Security.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Authorization failed. + /// + internal static string PolicyAuthorizationFailed { + get { + return ResourceManager.GetString("PolicyAuthorizationFailed", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Entity exists. + /// + internal static string ValExists { + get { + return ResourceManager.GetString("ValExists", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Invalid password. + /// + internal static string ValInvalidPassword { + get { + return ResourceManager.GetString("ValInvalidPassword", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to System roles are read only. + /// + internal static string ValSysRole { + get { + return ResourceManager.GetString("ValSysRole", resourceCulture); + } + } + } +} diff --git a/Connected.Security/SR.resx b/Connected.Security/SR.resx new file mode 100644 index 0000000..348ae0e --- /dev/null +++ b/Connected.Security/SR.resx @@ -0,0 +1,132 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Authorization failed + + + Entity exists + + + Invalid password + + + System roles are read only + + \ No newline at end of file diff --git a/Connected.Security/SecurityClaims.cs b/Connected.Security/SecurityClaims.cs new file mode 100644 index 0000000..28c76a5 --- /dev/null +++ b/Connected.Security/SecurityClaims.cs @@ -0,0 +1,11 @@ +namespace Connected.Security +{ + public static class SecurityClaims + { + public const string SecurityDelete = "Security Delete"; + public const string SecurityAdd = "Security Add"; + public const string SecurityModify = "Security Modify"; + public const string SecurityModifySelf = "Security Modify Self"; + public const string SecurityRead = "Security Read"; + } +} diff --git a/Connected.Security/SecurityExtensions.cs b/Connected.Security/SecurityExtensions.cs new file mode 100644 index 0000000..304403b --- /dev/null +++ b/Connected.Security/SecurityExtensions.cs @@ -0,0 +1,31 @@ +using Connected.Security.Identity; + +namespace Connected.Security; + +public static class SecurityExtensions +{ + public static string DisplayName(this IUser user) + { + if (user is null) + return null; + + string fn = FullName(user.FirstName, user.LastName); + + if (!string.IsNullOrWhiteSpace(fn)) + return fn; + + if (!string.IsNullOrWhiteSpace(user.LoginName)) + return user.LoginName; + + if (user.Email?.Contains('@') == true) + return user.Email[..user.Email.IndexOf('@')]; + + return user.Id.ToString(); + } + + public static string FullName(string firstName, string lastName) + { + return $"{firstName} {lastName}".Trim(); + } + +} diff --git a/Connected.Security/SecurityRoutes.cs b/Connected.Security/SecurityRoutes.cs new file mode 100644 index 0000000..0978ae9 --- /dev/null +++ b/Connected.Security/SecurityRoutes.cs @@ -0,0 +1,10 @@ +namespace Connected.Security +{ + public static class SecurityRoutes + { + public const string Roles = "/security/roles"; + public const string Users = "/security/users"; + public const string Membership = "/security/membership"; + public const string Permissions = "/security/permissions"; + } +} diff --git a/Connected.Security/SecurityStartup.cs b/Connected.Security/SecurityStartup.cs new file mode 100644 index 0000000..a40ebcb --- /dev/null +++ b/Connected.Security/SecurityStartup.cs @@ -0,0 +1,22 @@ +using Connected.Annotations; +using Connected.Security.Authentication.Middleware; +using Connected.Security.Authorization; +using Connected.Security.Cryptography; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Security; + +internal class SecurityStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddSingleton(typeof(ICryptographyService), typeof(CryptographyService)); + + services.AddScoped(typeof(IAuthenticationMiddleware), typeof(DefaultAuthenticationMiddleware)); + services.AddScoped(typeof(Authentication.IAuthenticationService), typeof(Authentication.AuthenticationService)); + services.AddScoped(typeof(IAuthenticationMiddleware), typeof(DefaultAuthenticationMiddleware)); + services.AddScoped(typeof(IAuthorizationService), typeof(AuthorizationService)); + } +} diff --git a/Connected.Services/Annotations/ServiceAuthorizationAttribute.cs b/Connected.Services/Annotations/ServiceAuthorizationAttribute.cs new file mode 100644 index 0000000..a8f1014 --- /dev/null +++ b/Connected.Services/Annotations/ServiceAuthorizationAttribute.cs @@ -0,0 +1,35 @@ +namespace Connected.Services.Annotations +{ + public enum AuthorizationPolicyBehavior + { + Mandatory = 1, + Optional = 2 + } + + public enum AuthorizationStage + { + Init = 1, + Result = 2 + } + + internal enum EnumOperation + { + HigherThan = 1, + AtLeast = 2, + } + + + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + public sealed class ServiceAuthorizationAttribute : Attribute + { + public ServiceAuthorizationAttribute(params string[] claims) + { + if (claims.Any()) + Claims = new List(claims); + } + public AuthorizationPolicyBehavior Behavior { get; set; } = AuthorizationPolicyBehavior.Mandatory; + public int Priority { get; set; } + public AuthorizationStage Stage { get; set; } = AuthorizationStage.Init; + public List Claims { get; } + } +} diff --git a/Connected.Services/ArgumentValueProvider.cs b/Connected.Services/ArgumentValueProvider.cs new file mode 100644 index 0000000..32b0cad --- /dev/null +++ b/Connected.Services/ArgumentValueProvider.cs @@ -0,0 +1,32 @@ +using Connected.Middleware; + +namespace Connected.Services; +/// +/// Default implementation of the interface. +/// +/// The type of the arguments to provide values for. +public abstract class ArgumentValueProvider : MiddlewareComponent, IArgumentValueProvider + where TArgs : IDto +{ + /// + /// The arguments object to which values are to be provided. + /// + protected TArgs Arguments { get; private set; } = default!; + /// + /// This method gets invoked by the platform. + /// + /// The arguments object to which the values can be provided. + public async Task Invoke(TArgs args) + { + Arguments = args; + + await OnInvoking(); + } + /// + /// Override this method to perform business logic. + /// + protected virtual async Task OnInvoking() + { + await Task.CompletedTask; + } +} diff --git a/Connected.Services/Authorization/DefaultServiceAuthorizationMiddleware.cs b/Connected.Services/Authorization/DefaultServiceAuthorizationMiddleware.cs new file mode 100644 index 0000000..d3cd69d --- /dev/null +++ b/Connected.Services/Authorization/DefaultServiceAuthorizationMiddleware.cs @@ -0,0 +1,11 @@ +using Connected.Security.Authorization; +using Connected.Security.Identity; + +namespace Connected.Services.Authorization; + +internal sealed class DefaultServiceAuthorizationMiddleware : ServiceAuthorizationMiddleware +{ + public DefaultServiceAuthorizationMiddleware(IAuthorizationService authorizationService, IIdentityService identityService) : base(authorizationService, identityService) + { + } +} diff --git a/Connected.Services/Authorization/IServiceAuthorizationMiddleware.cs b/Connected.Services/Authorization/IServiceAuthorizationMiddleware.cs new file mode 100644 index 0000000..32ae324 --- /dev/null +++ b/Connected.Services/Authorization/IServiceAuthorizationMiddleware.cs @@ -0,0 +1,14 @@ +using System.Collections.Immutable; + +namespace Connected.Services.Authorization; + +public interface IServiceAuthorizationMiddleware : IMiddleware +{ + Task> ResolveClaims(ImmutableArray claims); + + Task Authorize(ServiceAuthorizationMiddlewareArgs args) + where TArgs : IDto; + + Task Authorize(ServiceAuthorizationMiddlewareArgs args, TEntity entity) + where TArgs : IDto; +} diff --git a/Connected.Services/Authorization/ServiceAuthorizationContext.cs b/Connected.Services/Authorization/ServiceAuthorizationContext.cs new file mode 100644 index 0000000..e182349 --- /dev/null +++ b/Connected.Services/Authorization/ServiceAuthorizationContext.cs @@ -0,0 +1,217 @@ +using Connected.Interop; +using Connected.Middleware; +using Connected.Security.Authorization; +using Connected.ServiceModel; +using Connected.Services.Annotations; +using Connected.Threading; +using System.Collections.Immutable; +using System.Reflection; + +namespace Connected.Services.Authorization; + +internal sealed class ServiceAuthorizationContext : IAuthorizationContext +{ + public ServiceAuthorizationContext(IAuthorizationService authorization, IMiddlewareService middleware) + { + Authorization = authorization; + + Middleware = new AsyncLazy>(middleware.Query()); + } + + private IAuthorizationService Authorization { get; } + private AsyncLazy> Middleware { get; } + public AuthorizationContextState State { get; private set; } = AuthorizationContextState.Pending; + + public async Task Authorize(AuthorizationArgs args) + { + /* + * Authorization is performed only if the context does not have a grant permission + * for the execution. That means only the first authorization request is actually + * performed. Once the caller has been granted access no authorization takes place + * afterwards. + */ + if (State == AuthorizationContextState.Granted) + return AuthorizationResult.OK(); + + var result = await Authorization.Authorize(args); + /* + * Authorization completed. Set correct state which means no + * authorization is needed for any subsequent calls. + */ + if (result.Success) + State = AuthorizationContextState.Granted; + else + State = AuthorizationContextState.Revoked; + + return result; + } + + public async Task Authorize(ICallerContext context, TArgs args) where TArgs : IDto + { + /* + * We allow only one authorizaton process in the same context at the same time. + * We would cause a potential stack overflow otherwise. + */ + if (State == AuthorizationContextState.Authorizing) + return; + + if (State == AuthorizationContextState.Granted) + return; + + State = AuthorizationContextState.Authorizing; + + try + { + if (context.Sender is null) + return; + + if (ResolveAttributes(context, args) is not List attributes || !attributes.Any()) + return; + + var staging = attributes.Where(f => f.Stage == AuthorizationStage.Init); + + if (!staging.Any()) + return; + + Exception? firstFail = null; + bool onePassed = false; + + foreach (var attribute in staging) + { + try + { + if (attribute.Behavior == AuthorizationPolicyBehavior.Optional && onePassed) + continue; + + var claims = attribute.Claims.ToImmutableArray(); + + foreach (var middleware in await Middleware.Value) + claims = await middleware.ResolveClaims(claims); + + var middlewareArgs = new ServiceAuthorizationMiddlewareArgs(context, claims, args); + + foreach (var middleware in await Middleware.Value) + await middleware.Authorize(middlewareArgs); + + onePassed = true; + } + catch (Exception ex) + { + if (attribute.Behavior == AuthorizationPolicyBehavior.Mandatory) + throw; + + firstFail = ex; + } + } + + if (!onePassed && firstFail is not null) + throw firstFail; + + State = AuthorizationContextState.Granted; + } + catch + { + State = AuthorizationContextState.Revoked; + + throw; + } + } + + /// + /// Results are always authorized, event if is . + /// + /// + /// + /// + /// + /// + /// + public async Task Authorize(ICallerContext context, TArgs args, TComponent component) + where TArgs : IDto + { + if (context.Sender is null) + return component; + + if (ResolveAttributes(context, args) is not List attributes || !attributes.Any()) + return component; + + var staging = attributes.Where(f => f.Stage == AuthorizationStage.Result); + + if (!staging.Any()) + return component; + + Exception? firstFail = null; + bool onePassed = false; + var result = component; + + foreach (var attribute in staging) + { + try + { + if (attribute.Behavior == AuthorizationPolicyBehavior.Optional && onePassed) + continue; + + var claims = attribute.Claims.ToImmutableArray(); + + foreach (var middleware in await Middleware.Value) + claims = await middleware.ResolveClaims(claims); + + var middlewareArgs = new ServiceAuthorizationMiddlewareArgs(context, claims, args); + + foreach (var middleware in await Middleware.Value) + result = await middleware.Authorize(middlewareArgs, result); + + onePassed = true; + } + catch (Exception ex) + { + if (attribute.Behavior == AuthorizationPolicyBehavior.Mandatory) + throw; + + firstFail = ex; + } + } + + if (!onePassed && firstFail is not null) + throw firstFail; + + return result; + } + + public void Revoke() + { + if (State == AuthorizationContextState.Granted) + State = AuthorizationContextState.Revoked; + } + + private static List? ResolveAttributes(ICallerContext context, TArgs args) where TArgs : IDto + { + if (context.Sender is null || string.IsNullOrEmpty(context.Method)) + return default; + + if (context.Sender.GetType().ResolveMethod(context.Method, null, new Type[1] { args.GetType() }) is not MethodInfo method) + throw new InvalidOperationException($"{SR.ErrMethodResolve} ({context.Sender.GetType().Name}.{context.Method})"); + + var attributes = method.GetCustomAttributes(typeof(ServiceAuthorizationAttribute), false); + + if (attributes is null || !attributes.Any()) + return null; + + var result = new List(); + + foreach (var attribute in attributes) + result.Add((ServiceAuthorizationAttribute)attribute); + + result.Sort((left, right) => + { + if (left.Priority < right.Priority) + return -1; + else if (left.Priority == right.Priority) + return 0; + else + return 1; + }); + + return result; + } +} diff --git a/Connected.Services/Authorization/ServiceAuthorizationMiddleware.cs b/Connected.Services/Authorization/ServiceAuthorizationMiddleware.cs new file mode 100644 index 0000000..02c29d5 --- /dev/null +++ b/Connected.Services/Authorization/ServiceAuthorizationMiddleware.cs @@ -0,0 +1,94 @@ +using Connected.Entities; +using Connected.Middleware; +using Connected.Security.Authorization; +using Connected.Security.Identity; +using System.Collections.Immutable; +using System.Security; + +namespace Connected.Services.Authorization; + +public abstract class ServiceAuthorizationMiddleware : MiddlewareComponent, IServiceAuthorizationMiddleware +{ + protected ServiceAuthorizationMiddleware(IAuthorizationService authorizationService, IIdentityService identityService) + { + AuthorizationService = authorizationService; + IdentityService = identityService; + } + + protected IAuthorizationService AuthorizationService { get; } + protected IIdentityService IdentityService { get; } + + public Task> ResolveClaims(ImmutableArray claims) + { + return OnResolveClaims(claims); + } + + protected virtual Task> OnResolveClaims(ImmutableArray claims) + { + return Task.FromResult(claims); + } + + public async Task Authorize(ServiceAuthorizationMiddlewareArgs args) where TArgs : IDto + { + await OnAuthorize(args); + } + + protected virtual async Task OnAuthorize(ServiceAuthorizationMiddlewareArgs args) where TArgs : IDto + { + foreach (var claim in args.Claims) + { + var result = await AuthorizationService.Authorize(new AuthorizationArgs + { + User = IdentityService.CurrentUser?.Id, + Claim = claim, + Component = args.Context.Sender.GetType().Name, + Method = args.Context.Method, + Schema = new AuthorizationSchema + { + EmptyPolicy = EmptyPolicyBehavior.Alow, + Strategy = AuthorizationStrategy.Pessimistic + } + }); + + if (!result.Success) + throw new AccessViolationException($"{SR.PolicyAuthorizationFailed} ({result.Reason})"); + } + } + + public async Task Authorize(ServiceAuthorizationMiddlewareArgs args, TComponent component) + where TArgs : IDto + { + return await OnAuthorize(args, component); + } + + protected async Task OnAuthorize(ServiceAuthorizationMiddlewareArgs args, TComponent component) + where TArgs : IDto + { + var entity = component as IEntity; + + var primaryKey = entity is null ? component.ToString() : entity.PrimaryKeyValue(); + + foreach (var claim in args.Claims) + { + var result = await AuthorizationService.Authorize(new AuthorizationArgs + { + User = IdentityService.CurrentUser?.Id, + Claim = claim, + Component = args.Context.Sender.GetType().Name, + Method = args.Context.Method, + Schema = new AuthorizationSchema + { + EmptyPolicy = EmptyPolicyBehavior.Alow, + Strategy = AuthorizationStrategy.Pessimistic + }, + Entity = entity.EntityId(), + PrimaryKey = primaryKey + }); + + if (!result.Success) + throw new SecurityException($"{SR.PolicyAuthorizationFailed} ({result.Reason})"); + } + + return component; + } +} diff --git a/Connected.Services/Authorization/ServiceAuthorizationMiddlewareArgs.cs b/Connected.Services/Authorization/ServiceAuthorizationMiddlewareArgs.cs new file mode 100644 index 0000000..37cf313 --- /dev/null +++ b/Connected.Services/Authorization/ServiceAuthorizationMiddlewareArgs.cs @@ -0,0 +1,18 @@ +using System.Collections.Immutable; +using Connected.ServiceModel; + +namespace Connected.Services.Authorization; + +public sealed class ServiceAuthorizationMiddlewareArgs : EventArgs +{ + public ServiceAuthorizationMiddlewareArgs(ICallerContext context, ImmutableArray claims, TArgs args) + { + Context = context; + Claims = claims; + Args = args; + } + + public ICallerContext Context { get; } + public ImmutableArray Claims { get; } + public TArgs Args { get; } +} diff --git a/Connected.Services/Authorization/ServiceAuthorizationResult.cs b/Connected.Services/Authorization/ServiceAuthorizationResult.cs new file mode 100644 index 0000000..c05e592 --- /dev/null +++ b/Connected.Services/Authorization/ServiceAuthorizationResult.cs @@ -0,0 +1,21 @@ +namespace Connected.Services.Authorization +{ + public class ServiceAuthorizationResult + { + public ServiceAuthorizationResult() + { + Message = SR.PolicyAuthorizationFailed; + } + public ServiceAuthorizationResult(string message) + { + Message = message; + } + + public string Message { get; set; } + + public static ServiceAuthorizationResult Default(Attribute sender, object policy) + { + return new ServiceAuthorizationResult($"{SR.PolicyAuthorizationFailed} ({sender.GetType().Name}.{policy})"); + } + } +} diff --git a/Connected.Services/Connected.Services.csproj b/Connected.Services/Connected.Services.csproj new file mode 100644 index 0000000..1559ed0 --- /dev/null +++ b/Connected.Services/Connected.Services.csproj @@ -0,0 +1,38 @@ + + + + net7.0 + enable + enable + + + + + + + + + + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Services/DistributedService.cs b/Connected.Services/DistributedService.cs new file mode 100644 index 0000000..fc85359 --- /dev/null +++ b/Connected.Services/DistributedService.cs @@ -0,0 +1,40 @@ +using Connected.Net; +using Connected.Net.Server; +using Connected.ServiceModel; + +namespace Connected.Services; + +public abstract class DistributedService : Service +{ + public DistributedService(IContext context) : base(context) + { + if (context is null) + throw new ArgumentNullException(null, nameof(context)); + + EndpointServer = context.GetService(); + + if (EndpointServer is null) + throw new NullReferenceException(nameof(IEndpointServer)); + + Http = context.GetService(); + + if (Http is null) + throw new NullReferenceException(nameof(IHttpService)); + } + + private IEndpointServer EndpointServer { get; } + + protected async Task ParseUrl(string relativePath) + { + if (await IsServer()) + throw new InvalidOperationException(SR.ErrNoServer); + + return $"{EndpointServer.ServerUrl}/{relativePath.Trim('/')}"; + } + + protected IHttpService Http { get; } + protected async Task IsServer() + { + return await EndpointServer.IsServer(); + } +} diff --git a/Connected.Services/EntityService.cs b/Connected.Services/EntityService.cs new file mode 100644 index 0000000..dbf62c4 --- /dev/null +++ b/Connected.Services/EntityService.cs @@ -0,0 +1,39 @@ +using Connected.Notifications; +using Connected.Notifications.Events; +using Connected.ServiceModel; + +namespace Connected.Services; + +public abstract class EntityService : Service, IServiceNotifications +{ + public event ServiceEventHandler>? Inserted; + public event ServiceEventHandler>? Updated; + public event ServiceEventHandler>? Deleted; + + protected EntityService(IContext context) : base(context) + { + if (context.GetService() is IEventService events) + events.Event += OnEvent; + } + + private void OnEvent(IOperationState? sender, EventServiceArgs? e) + { + if (!e.Service.GetType().IsAssignableTo(GetType())) + return; + + if (string.Equals(e.Event, nameof(Inserted), StringComparison.Ordinal) && e.Arguments is PrimaryKeyEventArgs iargs) + Inserted?.Invoke(sender, iargs); + else if (string.Equals(e.Event, nameof(Updated), StringComparison.Ordinal) && e.Arguments is PrimaryKeyEventArgs uargs) + Inserted?.Invoke(sender, uargs); + else if (string.Equals(e.Event, nameof(Deleted), StringComparison.Ordinal) && e.Arguments is PrimaryKeyEventArgs dargs) + Inserted?.Invoke(sender, dargs); + } + + protected override void OnDisposing() + { + if (Context.GetService() is IEventService events) + events.Event -= OnEvent; + + base.OnDisposing(); + } +} diff --git a/Connected.Services/IAction.cs b/Connected.Services/IAction.cs new file mode 100644 index 0000000..7019a91 --- /dev/null +++ b/Connected.Services/IAction.cs @@ -0,0 +1,9 @@ +using Connected; + +namespace Connected.Services; + +public interface IAction : IServiceOperation + where TArgs : IDto +{ + Task Invoke(TArgs e); +} diff --git a/Connected.Services/IArgumentValueProvider.cs b/Connected.Services/IArgumentValueProvider.cs new file mode 100644 index 0000000..0171a77 --- /dev/null +++ b/Connected.Services/IArgumentValueProvider.cs @@ -0,0 +1,23 @@ +namespace Connected.Services; +/// +/// Middleware representing opportunity to set or modify values on the objects. +/// +/// +/// Some arguments provide properties that are not mandatory by the caller but must be set before +/// the operation is executed. This Middleware is called before the Validation phase occurs. +/// +/// +/// Serial value of the Stock item is not provided by the client but is needed before the goods can be +/// stored in the stock. The platform expects that a process will provide it before the Validation +/// occurs. Depending of the process implementation, it can create a new Serial or use existing one. +/// +/// The type of the arguments to be used by the middleware. +public interface IArgumentValueProvider : IMiddleware + where TArgs : IDto +{ + /// + /// This method gets called by the platform at the time when the values should be provided. + /// + /// The arguments instance on which values can be provided. + Task Invoke(TArgs args); +} diff --git a/Connected.Services/IDistributedService.cs b/Connected.Services/IDistributedService.cs new file mode 100644 index 0000000..cdec288 --- /dev/null +++ b/Connected.Services/IDistributedService.cs @@ -0,0 +1,6 @@ +namespace Connected.Services +{ + public interface IDistributedService + { + } +} diff --git a/Connected.Services/IFunction.cs b/Connected.Services/IFunction.cs new file mode 100644 index 0000000..4187971 --- /dev/null +++ b/Connected.Services/IFunction.cs @@ -0,0 +1,7 @@ +namespace Connected.Services; + +public interface IFunction : IServiceOperation + where TArgs : IDto +{ + Task Invoke(TArgs e); +} diff --git a/Connected.Services/INullableFunction.cs b/Connected.Services/INullableFunction.cs new file mode 100644 index 0000000..ed9ac26 --- /dev/null +++ b/Connected.Services/INullableFunction.cs @@ -0,0 +1,6 @@ +namespace Connected.Services; +public interface INullableFunction : IServiceOperation + where TArgs : IDto +{ + Task Invoke(TArgs e); +} diff --git a/Connected.Services/IService.cs b/Connected.Services/IService.cs new file mode 100644 index 0000000..03c3384 --- /dev/null +++ b/Connected.Services/IService.cs @@ -0,0 +1,7 @@ +namespace Connected.Services +{ + public interface IService + { + + } +} diff --git a/Connected.Services/IServiceOperation.cs b/Connected.Services/IServiceOperation.cs new file mode 100644 index 0000000..0d1ecf9 --- /dev/null +++ b/Connected.Services/IServiceOperation.cs @@ -0,0 +1,9 @@ +using Connected; + +namespace Connected.Services; + +public interface IServiceOperation + where TArgs : IDto +{ + +} diff --git a/Connected.Services/Middleware/ActionMiddleware.cs b/Connected.Services/Middleware/ActionMiddleware.cs new file mode 100644 index 0000000..1daf95d --- /dev/null +++ b/Connected.Services/Middleware/ActionMiddleware.cs @@ -0,0 +1,24 @@ +namespace Connected.Services.Middleware +{ + public abstract class ActionMiddleware : ServiceMiddleware, IActionMiddleware + { + protected TArgs? Arguments { get; private set; } + public async Task Invoke(TArgs? args) + { + Arguments = args; + + await OnValidate(); + await OnInvoke(); + } + + protected virtual async Task OnInvoke() + { + await Task.CompletedTask; + } + + protected virtual async Task OnValidate() + { + await Task.CompletedTask; + } + } +} diff --git a/Connected.Services/Middleware/FunctionMiddleware.cs b/Connected.Services/Middleware/FunctionMiddleware.cs new file mode 100644 index 0000000..69ee150 --- /dev/null +++ b/Connected.Services/Middleware/FunctionMiddleware.cs @@ -0,0 +1,26 @@ +namespace Connected.Services.Middleware +{ + public abstract class FunctionMiddleware : ServiceMiddleware, IFunctionMiddleware + { + protected TArgs? Arguments { get; private set; } + public async Task Invoke(TArgs? args, TReturnValue? result) + { + Arguments = args; + + await OnValidate(); + return await OnInvoke(result); + } + + protected virtual async Task OnInvoke(TReturnValue? result) + { + await Task.CompletedTask; + + return result; + } + + protected virtual async Task OnValidate() + { + await Task.CompletedTask; + } + } +} diff --git a/Connected.Services/Middleware/IActionMiddleware.cs b/Connected.Services/Middleware/IActionMiddleware.cs new file mode 100644 index 0000000..0cb3dc3 --- /dev/null +++ b/Connected.Services/Middleware/IActionMiddleware.cs @@ -0,0 +1,7 @@ +namespace Connected.Services.Middleware +{ + public interface IActionMiddleware : IServiceMiddleware + { + Task Invoke(TArgs? args); + } +} diff --git a/Connected.Services/Middleware/IFunctionMiddleware.cs b/Connected.Services/Middleware/IFunctionMiddleware.cs new file mode 100644 index 0000000..3b9ed7e --- /dev/null +++ b/Connected.Services/Middleware/IFunctionMiddleware.cs @@ -0,0 +1,7 @@ +namespace Connected.Services.Middleware +{ + public interface IFunctionMiddleware : IServiceMiddleware + { + Task Invoke(TArgs? args, TReturnValue? result); + } +} diff --git a/Connected.Services/Middleware/IServiceMiddleware.cs b/Connected.Services/Middleware/IServiceMiddleware.cs new file mode 100644 index 0000000..ec5668e --- /dev/null +++ b/Connected.Services/Middleware/IServiceMiddleware.cs @@ -0,0 +1,7 @@ +namespace Connected.Services.Middleware; + +public interface IServiceMiddleware : IMiddleware +{ + Task Commit(); + Task Rollback(); +} diff --git a/Connected.Services/Middleware/ServiceMiddleware.cs b/Connected.Services/Middleware/ServiceMiddleware.cs new file mode 100644 index 0000000..7e8af64 --- /dev/null +++ b/Connected.Services/Middleware/ServiceMiddleware.cs @@ -0,0 +1,27 @@ +using Connected.Middleware; + +namespace Connected.Services.Middleware +{ + public abstract class ServiceMiddleware : MiddlewareComponent, IServiceMiddleware + { + public async Task Commit() + { + await OnCommit(); + } + + protected virtual async Task OnCommit() + { + await Task.CompletedTask; + } + + public async Task Rollback() + { + await OnRollback(); + } + + protected virtual async Task OnRollback() + { + await Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/Connected.Services/NullableServiceFunction.cs b/Connected.Services/NullableServiceFunction.cs new file mode 100644 index 0000000..0f8808f --- /dev/null +++ b/Connected.Services/NullableServiceFunction.cs @@ -0,0 +1,24 @@ +namespace Connected.Services; +public abstract class NullableServiceFunction : ServiceOperation, INullableFunction + where TArgs : IDto +{ + protected TReturnValue? Result { get; private set; } + public async Task Invoke(TArgs args) + { + if (args is null) + throw new ArgumentException(null, nameof(args)); + + Arguments = args; + + Result = await OnInvoke(); + + return Result; + } + + protected virtual async Task OnInvoke() + { + await Task.CompletedTask; + + return default; + } +} diff --git a/Connected.Services/SR.Designer.cs b/Connected.Services/SR.Designer.cs new file mode 100644 index 0000000..757bc43 --- /dev/null +++ b/Connected.Services/SR.Designer.cs @@ -0,0 +1,99 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Services { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Services.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Extender has been created but does not implement IExtender<TArgs> interface. + /// + internal static string ErrExtenderCreate { + get { + return ResourceManager.GetString("ErrExtenderCreate", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Cannot resolve method. + /// + internal static string ErrMethodResolve { + get { + return ResourceManager.GetString("ErrMethodResolve", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Instance is endpoint server. + /// + internal static string ErrNoServer { + get { + return ResourceManager.GetString("ErrNoServer", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Authorization failed. + /// + internal static string PolicyAuthorizationFailed { + get { + return ResourceManager.GetString("PolicyAuthorizationFailed", resourceCulture); + } + } + } +} diff --git a/Connected.Services/SR.resx b/Connected.Services/SR.resx new file mode 100644 index 0000000..1e1977f --- /dev/null +++ b/Connected.Services/SR.resx @@ -0,0 +1,132 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Extender has been created but does not implement IExtender<TArgs> interface + + + Cannot resolve method + + + Instance is endpoint server + + + Authorization failed + + \ No newline at end of file diff --git a/Connected.Services/Service.cs b/Connected.Services/Service.cs new file mode 100644 index 0000000..696e32b --- /dev/null +++ b/Connected.Services/Service.cs @@ -0,0 +1,157 @@ +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Connected.Middleware; +using Connected.Security.Authorization; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; +using Connected.Services.Middleware; +using Connected.Validation; + +namespace Connected.Services; + +public abstract class Service : IService, IDisposable +{ + protected Service(IContext context) + { + Context = context; + + if (Context.GetService() is not IMiddlewareService middleware) + throw new NullReferenceException(nameof(IMiddlewareService)); + + Middleware = middleware; + } + + protected IContext Context { get; } + private IMiddlewareService Middleware { get; } + protected bool IsDisposed { get; private set; } + + protected async Task Invoke(IFunction function, TArgs args, [CallerMemberName] string? method = null) + where TArgs : IDto + { + var ctx = await Prepare(function, args, method); + var result = await function.Invoke(args); + + var middleware = await Middleware.Query>(ctx); + + if (!middleware.IsEmpty) + { + foreach (var m in middleware) + result = await m.Invoke(args, result); + } + + return await Authorize(ctx, args, result); + } + + protected async Task Invoke(INullableFunction function, TArgs args, [CallerMemberName] string? method = null) + where TArgs : IDto + { + var ctx = await Prepare(function, args, method); + var result = await function.Invoke(args); + + var middleware = await Middleware.Query>(ctx); + + if (!middleware.IsEmpty) + { + foreach (var m in middleware) + result = await m.Invoke(args, result); + } + + return await Authorize(ctx, args, result); + } + + protected async Task Invoke(IAction action, TArgs args, [CallerMemberName] string? method = null) + where TArgs : IDto + { + var ctx = await Prepare(action, args, method); + + await action.Invoke(args); + + var middleware = await Middleware.Query>(ctx); + + if (!middleware.IsEmpty) + { + foreach (var m in middleware) + await m.Invoke(args); + } + } + + private async Task Prepare(IServiceOperation operation, TArgs args, [CallerMemberName] string? method = null) + where TArgs : IDto + { + if (operation is ITransactionClient client && Context.GetService() is ITransactionContext transaction) + transaction.Register(client); + + var ctx = new CallerContext(this, method); + + await ProvideArgumentValues(args); + Validate(ctx, args); + await Authorize(ctx, args); + + return ctx; + } + + public TOperation GetOperation([CallerMemberName] string? method = null) + { + return Context.GetService(); + } + + private void Validate(ICallerContext context, TArgs args) + where TArgs : IDto + { + if (Context.GetService() is not IValidationContext validationContext) + return; + + validationContext.Validate(context, args); + } + + private async Task Authorize(ICallerContext context, TArgs args) + where TArgs : IDto + { + if (Context.GetService() is not IAuthorizationContext authorization) + return; + + await authorization.Authorize(context, args); + } + + private async Task Authorize(ICallerContext context, TArgs args, TResult result) + where TArgs : IDto + { + /* + * TODO: call data protection middleware and authorize each entity in the result. + * Will need to implement iterator which resolves all entities in the result set. + */ + await Task.CompletedTask; + + return result; + } + + private async Task ProvideArgumentValues(TArgs args) + where TArgs : IDto + { + if (await Middleware.Query>() is not ImmutableList> middleware || middleware.IsEmpty) + return; + + foreach (var m in middleware) + await m.Invoke(args); + } + + private void Dispose(bool disposing) + { + if (!IsDisposed) + { + if (disposing) + OnDisposing(); + + IsDisposed = true; + } + } + protected virtual void OnDisposing() + { + + } + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/Connected.Services/ServiceAction.cs b/Connected.Services/ServiceAction.cs new file mode 100644 index 0000000..070677f --- /dev/null +++ b/Connected.Services/ServiceAction.cs @@ -0,0 +1,17 @@ +namespace Connected.Services; + +public abstract class ServiceAction : ServiceOperation, IAction + where TArgs : IDto +{ + public async Task Invoke(TArgs e) + { + Arguments = e; + + await OnInvoke(); + } + + protected virtual async Task OnInvoke() + { + await Task.CompletedTask; + } +} diff --git a/Connected.Services/ServiceFunction.cs b/Connected.Services/ServiceFunction.cs new file mode 100644 index 0000000..884ac4f --- /dev/null +++ b/Connected.Services/ServiceFunction.cs @@ -0,0 +1,20 @@ +namespace Connected.Services; + +public abstract class ServiceFunction : ServiceOperation, IFunction + where TArgs : IDto +{ + protected TReturnValue Result { get; private set; } = default!; + public async Task Invoke(TArgs args) + { + if (args is null) + throw new ArgumentException(null, nameof(args)); + + Arguments = args; + + Result = await OnInvoke(); + + return Result; + } + + protected abstract Task OnInvoke(); +} diff --git a/Connected.Services/ServiceOperation.cs b/Connected.Services/ServiceOperation.cs new file mode 100644 index 0000000..6c73182 --- /dev/null +++ b/Connected.Services/ServiceOperation.cs @@ -0,0 +1,94 @@ +using System.Collections.Concurrent; +using Connected.Interop; +using Connected.ServiceModel; +using Connected.ServiceModel.Transactions; + +namespace Connected.Services; + +public abstract class ServiceOperation : IServiceOperation, ITransactionClient, IOperationState + where TArgs : IDto +{ + private TArgs _arguments; + + protected ServiceOperation() + { + State = new(); + } + + private ConcurrentDictionary State { get; } + + public TArgs Arguments + { + get => _arguments; + protected set + { + if (value is null) + throw new ArgumentException(nameof(Arguments)); + + _arguments = value; + } + } + + public TEntity? SetState(TEntity? entity) + { + var key = typeof(TEntity).FullName; + + if (string.IsNullOrEmpty(key)) + return entity; + + State.AddOrUpdate(key, entity, (existing, @new) => + { + return @new; + }); + + return entity; + } + + public TEntity? GetState() + { + var key = typeof(TEntity).FullName; + + if (string.IsNullOrEmpty(key)) + return default; + + if (!State.TryGetValue(key, out object? result)) + return default; + + if (TypeConversion.TryConvert(result, out TEntity? entity)) + return entity; + + return default; + } + + async Task ITransactionClient.Commit() + { + await OnCommitting(); + await OnCommitted(); + } + + async Task ITransactionClient.Rollback() + { + await OnRollingBack(); + await OnRolledBack(); + } + + protected virtual async Task OnCommitted() + { + await Task.CompletedTask; + } + + protected virtual async Task OnRolledBack() + { + await Task.CompletedTask; + } + + protected virtual async Task OnCommitting() + { + await Task.CompletedTask; + } + + protected virtual async Task OnRollingBack() + { + await Task.CompletedTask; + } +} diff --git a/Connected.Services/ServicesExtensions.cs b/Connected.Services/ServicesExtensions.cs new file mode 100644 index 0000000..a84f6b9 --- /dev/null +++ b/Connected.Services/ServicesExtensions.cs @@ -0,0 +1,54 @@ +using System.Reflection; +using Connected.Annotations; +using Connected.Services.Middleware; + +namespace Connected.Services; + +public static class ServicesExtensions +{ + public static List GetImplementedServices(this Type type) + { + var result = new List(); + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (i.GetCustomAttribute() is not null) + result.Add(i); + } + + return result; + } + + public static bool IsService(this Type type) + { + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (i.GetCustomAttribute() is not null) + return true; + } + + return false; + } + + public static bool IsServiceMiddleware(this Type type) + { + return type.GetInterface(typeof(IServiceMiddleware<>).FullName) is not null; + } + + public static bool IsServiceOperation(this Type type) + { + return type.GetInterface(typeof(IServiceOperation<>).FullName) is not null; + } + + public static bool IsServiceFunction(this Type type) + { + var nf = typeof(INullableFunction<,>).FullName; + var f = typeof(IFunction<,>).FullName; + + return (f is not null && type.GetInterface(f) is not null) + || (nf is not null && type.GetInterface(nf) is not null); + } +} \ No newline at end of file diff --git a/Connected.Services/ServicesStartup.cs b/Connected.Services/ServicesStartup.cs new file mode 100644 index 0000000..e3d70b6 --- /dev/null +++ b/Connected.Services/ServicesStartup.cs @@ -0,0 +1,16 @@ +using Connected.Annotations; +using Connected.Security.Authorization; +using Connected.Services.Authorization; +using Microsoft.Extensions.DependencyInjection; + +[assembly: MicroService(MicroServiceType.Sys)] + +namespace Connected.Services; + +internal class ServicesStartup : Startup +{ + protected override void OnConfigureServices(IServiceCollection services) + { + services.AddScoped(typeof(IAuthorizationContext), typeof(ServiceAuthorizationContext)); + } +} diff --git a/Connected.Threading/AsyncLazy.cs b/Connected.Threading/AsyncLazy.cs new file mode 100644 index 0000000..c403fae --- /dev/null +++ b/Connected.Threading/AsyncLazy.cs @@ -0,0 +1,9 @@ +namespace Connected.Threading; +public class AsyncLazy : Lazy> +{ + public AsyncLazy(Task value) + : base(value) + { + + } +} diff --git a/Connected.Threading/AsyncLocker.cs b/Connected.Threading/AsyncLocker.cs new file mode 100644 index 0000000..a52b0a7 --- /dev/null +++ b/Connected.Threading/AsyncLocker.cs @@ -0,0 +1,74 @@ +using System.Collections.Concurrent; +using Connected; + +namespace Connected.Threading; + +public class AsyncLocker : IDisposable +{ + private ConcurrentDictionary _items = new(); + private bool _disposed; + + private ConcurrentDictionary Items => _items; + + public async Task LockAsync(int semaphore, Func worker) + { + if (Items.TryGetValue(semaphore, out AsyncLockerSlim? locker)) + await locker.LockAsync(worker); + + var newLocker = new AsyncLockerSlim(); + + if (Items.TryAdd(semaphore, newLocker)) + await newLocker.LockAsync(worker); + else + { + newLocker.Dispose(); + + if (!Items.TryGetValue(semaphore, out AsyncLockerSlim? existing)) + throw new SysException(this, SR.ErrLock); + + await existing.LockAsync(worker); + } + } + + public async Task LockAsync(int semaphore, Func> worker) + { + if (Items.TryGetValue(semaphore, out AsyncLockerSlim? locker)) + return await locker.LockAsync(worker); + + var newLocker = new AsyncLockerSlim(); + + if (Items.TryAdd(semaphore, newLocker)) + return await newLocker.LockAsync(worker); + else + { + newLocker.Dispose(); + + if (!Items.TryGetValue(semaphore, out AsyncLockerSlim? existing)) + throw new SysException(this, SR.ErrLock); + + return await existing.LockAsync(worker); + } + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + foreach (var slim in Items) + slim.Value.Dispose(); + + Items.Clear(); + } + + _disposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Threading/AsyncLockerSlim.cs b/Connected.Threading/AsyncLockerSlim.cs new file mode 100644 index 0000000..585239d --- /dev/null +++ b/Connected.Threading/AsyncLockerSlim.cs @@ -0,0 +1,58 @@ +namespace Connected.Threading; + +public class AsyncLockerSlim : IDisposable +{ + /* + * this should be upgraded in the future: + * https://stackoverflow.com/questions/24139084/semaphoreslim-waitasync-before-after-try-block/61806749#61806749 + */ + private readonly SemaphoreSlim _semaphore = new(1, 1); + private bool _disposed; + + public async Task LockAsync(Func worker) + { + await _semaphore.WaitAsync(); + + try + { + await worker(); + } + finally + { + _semaphore.Release(); + } + } + + public async Task LockAsync(Func> worker) + { + await _semaphore.WaitAsync(); + + try + { + return await worker(); + } + finally + { + _semaphore.Release(); + } + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + _semaphore.Dispose(); + } + + _disposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/Connected.Threading/Connected.Threading.csproj b/Connected.Threading/Connected.Threading.csproj new file mode 100644 index 0000000..b68714a --- /dev/null +++ b/Connected.Threading/Connected.Threading.csproj @@ -0,0 +1,28 @@ + + + + net7.0 + enable + enable + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Threading/SR.Designer.cs b/Connected.Threading/SR.Designer.cs new file mode 100644 index 0000000..9a1d472 --- /dev/null +++ b/Connected.Threading/SR.Designer.cs @@ -0,0 +1,72 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Threading { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Server.Threading.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Cannot obtain lock. + /// + internal static string ErrLock { + get { + return ResourceManager.GetString("ErrLock", resourceCulture); + } + } + } +} diff --git a/Connected.Threading/SR.resx b/Connected.Threading/SR.resx new file mode 100644 index 0000000..afd08c7 --- /dev/null +++ b/Connected.Threading/SR.resx @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cannot obtain lock + + \ No newline at end of file diff --git a/Connected.Threading/ScheduledTask.cs b/Connected.Threading/ScheduledTask.cs new file mode 100644 index 0000000..1f0a335 --- /dev/null +++ b/Connected.Threading/ScheduledTask.cs @@ -0,0 +1,160 @@ +namespace Connected.Threading; +/// +/// Represents asynchronous task which invokes action on predefined interval. +/// +/// +/// This class is useful when performing tasks which must be completed in specified time but +/// supports pinging or similar techniques which enable tasks that do not complete in time to +/// prolong their execution. +/// +/// +/// If queue task must complete in 30 seconds before becomes visible again, +/// would be useful to act as a guard and prevent queue message to become visible before the processing +/// is completed. We would set the to 25 seconds and call ping which will give +/// us another 30 seconds to complete. +/// +public sealed class ScheduledTask : IDisposable +{ + /// + /// Creates a new instance of a . + /// + /// The action to be called when timeout occurs. + /// The action to be called when task exceeds the execution. + /// The timeout before the scheduledAction is called. + /// The total time that is allowed this task to be run. + /// The cancellation token to cancel the task. + public ScheduledTask(Func scheduledAction, Func expiredAction, TimeSpan timeout, TimeSpan lifetime, CancellationToken cancel) + { + ScheduledAction = scheduledAction; + ExpiredAction = expiredAction; + Timeout = timeout; + Lifetime = lifetime; + CancelSource = new CancellationTokenSource(); + + cancel.Register(CancelSource.Cancel); + } + /// + /// The internal cancellation source which is passed in the task to enable cancellation. + /// + private CancellationTokenSource CancelSource { get; } + /// + /// The action or callback which is called when timeout occurs. + /// + private Func ScheduledAction { get; } + /// + /// This action is called if the executes past . + /// + public Func ExpiredAction { get; } + /// + /// The actual Task which performs the logic. + /// + private Task? Task { get; set; } + /// + /// The timeout before the is called. + /// + private TimeSpan Timeout { get; } + /// + /// The total number that is allowe the task to be run. + /// + public TimeSpan Lifetime { get; } + /// + /// Gets or sets value which determines if the is + /// currently running or not. + /// + private bool IsRunning { get; set; } + /// + /// Starts the if it is not already running. + /// + public void Start() + { + if (IsRunning) + return; + /* + * Create new async task. + */ + Task = Task.Run(async () => + { + var start = DateTime.UtcNow; + + try + { + /* + * We are calling the ScheduledAction repeatedly until the: + * a) cancellation is requested + * b) stop is called which sets IsRunning property to false + */ + while (!CancelSource.IsCancellationRequested || IsRunning) + { + /* + * If Dispose has been called in the meantime Task would be null here + */ + if (Task is not null) + { + /* + * First wait for the duration of the Timeout before triggering callback + */ + await Task.Delay(Timeout, CancelSource.Token).ConfigureAwait(false); + /* + * If the lifetime exceedes stop the task and call ExpiredAction instead on ScheduledAction. + */ + if (DateTime.UtcNow.Subtract(start) > Lifetime) + { + Stop(); + + await ExpiredAction(); + + return; + } + + /* + * Once the Timeout elapsed invoke the callback + */ + await ScheduledAction().ConfigureAwait(false); + } + } + } + finally + { + IsRunning = false; + } + }, CancelSource.Token); + + IsRunning = true; + } + /// + /// Stops the execution loop. + /// + public void Stop() + { + if (CancelSource.IsCancellationRequested) + return; + + try + { + /* + * This will immeadiatelly stop the Task. + */ + CancelSource.Cancel(); + + } + catch (OperationCanceledException) + { + + } + finally + { + IsRunning = false; + } + } + /// + /// Disposes the object by stopping the task if it is running. + /// + public void Dispose() + { + Stop(); + CancelSource.Dispose(); + + Task?.Dispose(); + Task = null; + } +} \ No newline at end of file diff --git a/Connected.Validation/Annotations/SkipValidationAttribute.cs b/Connected.Validation/Annotations/SkipValidationAttribute.cs new file mode 100644 index 0000000..4407ecb --- /dev/null +++ b/Connected.Validation/Annotations/SkipValidationAttribute.cs @@ -0,0 +1,7 @@ +namespace Connected.Validation.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class SkipValidationAttribute : Attribute + { + } +} diff --git a/Connected.Validation/Annotations/ValidateAntiforgeryAttribute.cs b/Connected.Validation/Annotations/ValidateAntiforgeryAttribute.cs new file mode 100644 index 0000000..93c3e4c --- /dev/null +++ b/Connected.Validation/Annotations/ValidateAntiforgeryAttribute.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; + +namespace Connected.Validation.Annotations +{ + [AttributeUsage(AttributeTargets.Class)] + internal class ValidateAntiforgeryAttribute : ValidationAttribute + { + public ValidateAntiforgeryAttribute() + { + + } + + public ValidateAntiforgeryAttribute(bool validateRequest) + { + ValidateRequest = validateRequest; + } + public bool ValidateRequest { get; } + } +} diff --git a/Connected.Validation/Annotations/ValidateRequestAttribute.cs b/Connected.Validation/Annotations/ValidateRequestAttribute.cs new file mode 100644 index 0000000..f52e103 --- /dev/null +++ b/Connected.Validation/Annotations/ValidateRequestAttribute.cs @@ -0,0 +1,13 @@ +namespace Connected.Validation.Annotations +{ + [AttributeUsage(AttributeTargets.Property)] + public sealed class ValidateRequestAttribute : Attribute + { + public ValidateRequestAttribute(bool validate) + { + Validate = validate; + } + + public bool Validate { get; } + } +} diff --git a/Connected.Validation/Connected.Validation.csproj b/Connected.Validation/Connected.Validation.csproj new file mode 100644 index 0000000..ac43842 --- /dev/null +++ b/Connected.Validation/Connected.Validation.csproj @@ -0,0 +1,29 @@ + + + + net7.0 + enable + enable + + + + + + + + + + True + True + SR.resx + + + + + + ResXFileCodeGenerator + SR.Designer.cs + + + + diff --git a/Connected.Validation/IValidationContext.cs b/Connected.Validation/IValidationContext.cs new file mode 100644 index 0000000..91e7bb0 --- /dev/null +++ b/Connected.Validation/IValidationContext.cs @@ -0,0 +1,9 @@ +using Connected.ServiceModel; + +namespace Connected.Validation; + +public interface IValidationContext +{ + void Validate(ICallerContext context, TArgs value) + where TArgs : IDto; +} diff --git a/Connected.Validation/IValidator.cs b/Connected.Validation/IValidator.cs new file mode 100644 index 0000000..100a46d --- /dev/null +++ b/Connected.Validation/IValidator.cs @@ -0,0 +1,6 @@ +namespace Connected.Validation; + +public interface IValidator : IMiddleware +{ + Task Validate(TArgs args); +} diff --git a/Connected.Validation/SR.Designer.cs b/Connected.Validation/SR.Designer.cs new file mode 100644 index 0000000..03ac094 --- /dev/null +++ b/Connected.Validation/SR.Designer.cs @@ -0,0 +1,144 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Connected.Validation { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class SR { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal SR() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Connected.Validation.SR", typeof(SR).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Invalid Antiforgery token. + /// + internal static string ValAntiForgery { + get { + return ResourceManager.GetString("ValAntiForgery", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Duplicate value. + /// + internal static string ValDuplicate { + get { + return ResourceManager.GetString("ValDuplicate", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Entity is disabled. + /// + internal static string ValEntityDisabled { + get { + return ResourceManager.GetString("ValEntityDisabled", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Enum value not defined. + /// + internal static string ValEnumValueNotDefined { + get { + return ResourceManager.GetString("ValEnumValueNotDefined", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Value contains invalid character. + /// + internal static string ValInvalidChars { + get { + return ResourceManager.GetString("ValInvalidChars", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Value mismatch. + /// + internal static string ValMismatch { + get { + return ResourceManager.GetString("ValMismatch", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Not found. + /// + internal static string ValNotFound { + get { + return ResourceManager.GetString("ValNotFound", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Referenced record exists. + /// + internal static string ValReference { + get { + return ResourceManager.GetString("ValReference", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Script tag is not allowed. + /// + internal static string ValScriptTagNotAllowed { + get { + return ResourceManager.GetString("ValScriptTagNotAllowed", resourceCulture); + } + } + } +} diff --git a/Connected.Validation/SR.resx b/Connected.Validation/SR.resx new file mode 100644 index 0000000..395dc8c --- /dev/null +++ b/Connected.Validation/SR.resx @@ -0,0 +1,147 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Invalid Antiforgery token + + + Duplicate value + + + Entity is disabled + + + Enum value not defined + + + Value contains invalid character + + + Value mismatch + + + Not found + + + Referenced record exists + + + Script tag is not allowed + + \ No newline at end of file diff --git a/Connected.Validation/ValidationContext.cs b/Connected.Validation/ValidationContext.cs new file mode 100644 index 0000000..527b773 --- /dev/null +++ b/Connected.Validation/ValidationContext.cs @@ -0,0 +1,248 @@ +using System.Collections; +using System.Collections.Immutable; +using System.ComponentModel.DataAnnotations; +using System.Reflection; +using System.Web; +using Connected.Configuration; +using Connected.Interop; +using Connected.Middleware; +using Connected.Security.Authorization; +using Connected.ServiceModel; +using Connected.Validation.Annotations; +using Microsoft.AspNetCore.Antiforgery; +using Microsoft.AspNetCore.Http; + +namespace Connected.Validation; + +internal class ValidationContext : IValidationContext +{ + public ValidationContext(IContext context, IAuthorizationContext authorization, + IConfigurationService configuration, IHttpContextAccessor http, IAntiforgery antiforgery, + IMiddlewareService middleware) + { + Context = context; + Authorization = authorization; + Configuration = configuration; + Http = http; + Antiforgery = antiforgery; + Middleware = middleware; + } + + private IContext? Context { get; } + private IAuthorizationContext? Authorization { get; } + private IConfigurationService? Configuration { get; } + private IHttpContextAccessor? Http { get; } + private IAntiforgery? Antiforgery { get; } + private IMiddlewareService? Middleware { get; } + + public void Validate(ICallerContext context, TArgs value) + where TArgs : IDto + { + ValidateAntiforgery(value); + + var results = new List(); + var refs = new List(); + + Validate(results, value, refs); + + if (results.Any()) + throw new ValidationException(results[0].ErrorMessage); + + if (Middleware is null) + return; + + if (Middleware.Query>(context) is not ImmutableList> middleware) + return; + + foreach (var m in middleware) + m.Validate(value); + } + + private async void ValidateAntiforgery(object? value) + { + if (Antiforgery is null || value is null) + return; + + if (value.GetType().GetCustomAttribute() is ValidateAntiforgeryAttribute attribute && !attribute.ValidateRequest) + return; + + if (Configuration is null || Configuration.Type != ProcessType.BackEnd) + return; + + if (Http?.HttpContext?.Request is null || !Http.HttpContext.Request.IsAjaxRequest()) + return; + /* + * No need to validate antiforgery more than once. + */ + if (Authorization is not null && Authorization.State == AuthorizationContextState.Granted) + return; + + if (!await Antiforgery.IsRequestValidAsync(Http.HttpContext)) + return; + + throw new ValidationException(SR.ValAntiForgery); + } + + private void Validate(List results, object? value, List references) + { + if (value is null) + return; + + if (value.GetType().IsTypePrimitive()) + return; + + if (value is null || references.Contains(value)) + return; + + references.Add(value); + + var properties = value.GetType().GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + if (!properties.Any()) + return; + + var publicProps = new List(); + var nonPublicProps = new List(); + + foreach (var property in properties) + { + if (property.GetMethod is null) + continue; + + if (property.GetCustomAttribute() is not null) + continue; + + if (property.GetMethod.IsPublic) + publicProps.Add(property); + else + nonPublicProps.Add(property); + } + /* + * First, iterate only through the public properties + * At this point we won't validate complex objects, only the attributes directly on the + * passed instance + */ + foreach (var property in publicProps) + ValidateProperty(results, value, property); + /* + * If root validation failed we won't go deep because this would probably cause + * duplicate and/or confusing validation messages + */ + if (results.Any()) + return; + /* + * Second step is to validate complex public members and collections. + */ + foreach (var property in publicProps) + { + if (property.PropertyType.IsEnumerable()) + { + if (GetValue(value, property) is not IEnumerable ien) + continue; + + var en = ien.GetEnumerator(); + + while (en.MoveNext()) + { + if (en.Current is null) + continue; + + Validate(results, en.Current, references); + } + } + else + { + if (GetValue(value, property) is not object val) + continue; + + Validate(results, val, references); + } + } + /* + * If any complex validation failed we won't validate private members because + * it is possible that initialization would fail for the reason of validation being failed. + */ + if (results.Any()) + return; + /* + * Now that validation of the public properties succeed we can go validate nonpublic members + */ + foreach (var property in nonPublicProps) + ValidateProperty(results, value, property); + } + + private void ValidateProperty(List results, object? value, PropertyInfo property) + { + var attributes = property.GetCustomAttributes(false); + + if (!ValidateRequestValue(results, value, property)) + return; + + if (property.PropertyType.IsEnum && !Enum.TryParse(property.PropertyType, TypeConversion.Convert(property.GetValue(value)), out _)) + results.Add(new ValidationResult($"{SR.ValEnumValueNotDefined} ({property.PropertyType.ShortName()}, {property.GetValue(value)})", new string[] { property.Name })); + + foreach (var attribute in attributes) + { + if (attribute is ValidationAttribute val) + { + var serviceProvider = new ValidationServiceProvider(Context); + var displayName = property.Name; + + var ctx = new System.ComponentModel.DataAnnotations.ValidationContext(value, serviceProvider, new Dictionary()) + { + DisplayName = displayName.ToLower(), + MemberName = property.Name, + }; + + val.Validate(GetValue(value, property), ctx); + } + } + } + + private static bool ValidateRequestValue(List results, PropertyInfo property, object? value) + { + if (value is null) + return true; + + var att = property.FindAttribute(); + + if (att is not null && !att.Validate) + return true; + + if (HttpUtility.HtmlDecode(value.ToString()) is not string decoded) + return true; + + if (decoded.Replace(" ", string.Empty).Contains("