From d7b0a471f078b549d0549ede54419c9bca150211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matija=20Ko=C5=BEelj?= Date: Mon, 5 Dec 2022 17:35:00 +0100 Subject: [PATCH] Rebase with github repository --- Connected.Data/Schema/ExistingColumn.cs | 2 + Connected.Data/Schema/ISchemaColumn.cs | 3 + Connected.Data/Schema/SchemaColumn.cs | 6 +- Connected.Data/Schema/SchemaService.cs | 5 +- .../Schema/Sql/AdHocSchemaEntity.cs | 3 +- .../Schema/Sql/SchemaExecutionContext.cs | 8 +- Connected.Data/Schema/Sql/SpHelp.cs | 272 ++- Connected.Data/Storage/ConnectionProvider.cs | 6 +- Connected.Data/Update/CommandBuilder.cs | 405 ++-- .../Formatters/SqlFormatter.cs | 1954 +++++++++-------- .../Rewriters/WhereClauseRewriter.cs | 30 +- Connected.Instance/EntitySynchronizer.cs | 17 +- Connected.Interop/Generics.cs | 28 + Connected.Interop/Merging/JsonMerger.cs | 2 +- Connected.Interop/Serializer.cs | 41 +- Connected.Interop/TypeSystem.cs | 4 +- Connected.Net/HttpExtensions.cs | 6 +- 17 files changed, 1434 insertions(+), 1358 deletions(-) diff --git a/Connected.Data/Schema/ExistingColumn.cs b/Connected.Data/Schema/ExistingColumn.cs index 9c4efd4..fa3c379 100644 --- a/Connected.Data/Schema/ExistingColumn.cs +++ b/Connected.Data/Schema/ExistingColumn.cs @@ -1,5 +1,6 @@ using System.Collections.Immutable; using System.Data; +using System.Reflection; using Connected.Data.Schema.Sql; using Connected.Entities.Annotations; @@ -44,6 +45,7 @@ internal class ExistingColumn : ISchemaColumn, IExistingSchemaColumn public DateKind DateKind { get; set; } = DateKind.DateTime; public BinaryKind BinaryKind { get; set; } = BinaryKind.VarBinary; + public PropertyInfo Property { get; set; } public int DatePrecision { get; set; } public ImmutableArray QueryIndexColumns(string column) diff --git a/Connected.Data/Schema/ISchemaColumn.cs b/Connected.Data/Schema/ISchemaColumn.cs index d24fcfc..e954e82 100644 --- a/Connected.Data/Schema/ISchemaColumn.cs +++ b/Connected.Data/Schema/ISchemaColumn.cs @@ -1,4 +1,5 @@ using System.Data; +using System.Reflection; using Connected.Entities.Annotations; namespace Connected.Data.Schema; @@ -21,4 +22,6 @@ public interface ISchemaColumn DateKind DateKind { get; } BinaryKind BinaryKind { get; } int DatePrecision { get; } + + PropertyInfo Property { get; } } diff --git a/Connected.Data/Schema/SchemaColumn.cs b/Connected.Data/Schema/SchemaColumn.cs index 77f922d..5115c12 100644 --- a/Connected.Data/Schema/SchemaColumn.cs +++ b/Connected.Data/Schema/SchemaColumn.cs @@ -1,13 +1,15 @@ using System.Data; +using System.Reflection; using Connected.Entities.Annotations; namespace Connected.Data.Schema; internal class SchemaColumn : IEquatable, ISchemaColumn { - public SchemaColumn(ISchema schema) + public SchemaColumn(ISchema schema, PropertyInfo property) { Schema = schema; + Property = property; } private ISchema Schema { get; } @@ -30,6 +32,8 @@ internal class SchemaColumn : IEquatable, ISchemaColumn public int Ordinal { get; set; } + public PropertyInfo Property { get; set; } + public bool Equals(ISchemaColumn? other) { if (other is null) diff --git a/Connected.Data/Schema/SchemaService.cs b/Connected.Data/Schema/SchemaService.cs index 6423c8e..c8600b6 100644 --- a/Connected.Data/Schema/SchemaService.cs +++ b/Connected.Data/Schema/SchemaService.cs @@ -60,6 +60,7 @@ internal class SchemaService : ISchemaService await middleware.Synchronize(entity, schema); synchronized = true; + break; } /* * We should notify the environment that entity is no synchronized. @@ -80,7 +81,7 @@ internal class SchemaService : ISchemaService { var persistence = entityType.GetCustomAttribute(); - return persistence is null || !persistence.Persistence.HasFlag(ColumnPersistence.Write); + return persistence is null || persistence.Persistence.HasFlag(ColumnPersistence.Write); } private static ISchema CreateSchema(Type type) @@ -104,7 +105,7 @@ internal class SchemaService : ISchemaService if (property.FindAttribute() is PersistenceAttribute pa && pa.IsVirtual) continue; - var column = new SchemaColumn(result) + var column = new SchemaColumn(result, property) { Name = ResolveColumnName(property), DataType = DataExtensions.ToDbType(property) diff --git a/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs b/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs index 7aa3d8b..15d3e73 100644 --- a/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs +++ b/Connected.Data/Schema/Sql/AdHocSchemaEntity.cs @@ -4,8 +4,7 @@ using Connected.Entities.Annotations; namespace Connected.Data.Schema.Sql; [Persistence(Persistence = ColumnPersistence.InMemory)] -internal sealed class AdHocSchemaEntity : IEntity +internal sealed record AdHocSchemaEntity : Entity { - public State State { get; init; } public bool Result { get; init; } } diff --git a/Connected.Data/Schema/Sql/SchemaExecutionContext.cs b/Connected.Data/Schema/Sql/SchemaExecutionContext.cs index 01d7097..4898ed4 100644 --- a/Connected.Data/Schema/Sql/SchemaExecutionContext.cs +++ b/Connected.Data/Schema/Sql/SchemaExecutionContext.cs @@ -1,6 +1,6 @@ -using Connected.Data.Sql; +using System.Data; +using Connected.Data.Sql; using Connected.Entities.Storage; -using System.Data; namespace Connected.Data.Schema.Sql; @@ -50,7 +50,7 @@ internal class SchemaExecutionContext public async Task Execute(string commandText) { - await Storage.Open().Select(new SchemaStorageArgs(new StorageOperation { CommandText = commandText }, typeof(SqlDataConnection), ConnectionString)); + await Storage.Open().Execute(new SchemaStorageArgs(new StorageOperation { CommandText = commandText }, typeof(SqlDataConnection), ConnectionString)); } public async Task Select(string commandText) @@ -74,7 +74,7 @@ internal class SchemaExecutionContext Constraints.Add(type, existing); } - if (ConstraintNameExists(name)) + if (!ConstraintNameExists(name)) existing.Add(name); } diff --git a/Connected.Data/Schema/Sql/SpHelp.cs b/Connected.Data/Schema/Sql/SpHelp.cs index 3bfe615..270747b 100644 --- a/Connected.Data/Schema/Sql/SpHelp.cs +++ b/Connected.Data/Schema/Sql/SpHelp.cs @@ -1,145 +1,143 @@ -using Connected.Data.Storage; +using System.Data; using Connected.Entities.Storage; using Connected.Interop; -using System.Data; namespace Connected.Data.Schema.Sql; internal class SpHelp : SynchronizationQuery { - private readonly ObjectDescriptor _descriptor; - - public SpHelp() - { - _descriptor = new(); - } - - private ObjectDescriptor Result => _descriptor; - - protected override async Task OnExecute() - { - var operation = new StorageOperation { CommandText = "sp_help", CommandType = CommandType.Text }; - - operation.AddParameter(new StorageParameter - { - Name = "@objname", - Type = DbType.String, - Value = Escape(Context.Schema.SchemaName(), Context.Schema.Name) - } - ); - - var rdr = await Context.OpenReader(operation); - - try - { - ReadMetadata(rdr); - ReadColumns(rdr); - ReadIdentity(rdr); - ReadRowGuid(rdr); - ReadFileGroup(rdr); - ReadIndexes(rdr); - ReadConstraints(rdr); - } - finally - { - rdr.Close(); - } - - return Result; - } - - private void ReadMetadata(IDataReader rdr) - { - if (rdr.Read()) - { - Result.MetaData.Name = rdr.GetValue("Name", string.Empty); - Result.MetaData.Created = rdr.GetValue("Created_datetime", DateTime.MinValue); - Result.MetaData.Owner = rdr.GetValue("Owner", string.Empty); - Result.MetaData.Type = rdr.GetValue("Type", string.Empty); - } - } - - private void ReadColumns(IDataReader rdr) - { - rdr.NextResult(); - - while (rdr.Read()) - { - Result.Columns.Add(new ObjectColumn - { - Collation = rdr.GetValue("Collation", string.Empty), - Computed = !string.Equals(rdr.GetValue("Computed", string.Empty), "no", StringComparison.OrdinalIgnoreCase), - FixedLenInSource = rdr.GetValue("FixedLenNullInSource", string.Empty), - Length = rdr.GetValue("Length", 0), - Name = rdr.GetValue("Column_name", string.Empty), - Nullable = !string.Equals(rdr.GetValue("Nullable", string.Empty), "no", StringComparison.OrdinalIgnoreCase), - Precision = TypeConversion.Convert(rdr.GetValue("Prec", string.Empty).Trim()), - Scale = TypeConversion.Convert(rdr.GetValue("Scale", string.Empty).Trim()), - TrimTrailingBlanks = rdr.GetValue("TrimTrailingBlanks", string.Empty), - Type = rdr.GetValue("Type", string.Empty) - }); - } - } - - private void ReadIdentity(IDataReader rdr) - { - rdr.NextResult(); - - if (rdr.Read()) - { - Result.Identity.Identity = rdr.GetValue("Identity", string.Empty); - Result.Identity.Increment = rdr.GetValue("Increment", 0); - Result.Identity.NotForReplication = rdr.GetValue("Not For Replication", 0) != 0; - } - } - - private void ReadRowGuid(IDataReader rdr) - { - rdr.NextResult(); - - if (rdr.Read()) - Result.RowGuid.RowGuidCol = rdr.GetValue("RowGuidCol", string.Empty); - } - - private void ReadFileGroup(IDataReader rdr) - { - rdr.NextResult(); - - if (rdr.Read()) - Result.FileGroup.FileGroup = rdr.GetValue("Data_located_on_filegroup", string.Empty); - } - - private void ReadIndexes(IDataReader rdr) - { - rdr.NextResult(); - - while (rdr.Read()) - { - Result.Indexes.Add(new ObjectIndex - { - Description = rdr.GetValue("index_description", string.Empty), - Keys = rdr.GetValue("index_keys", string.Empty), - Name = rdr.GetValue("index_name", string.Empty) - }); - } - } - - private void ReadConstraints(IDataReader rdr) - { - rdr.NextResult(); - - while (rdr.Read()) - { - Result.Constraints.Add(new ObjectConstraint - { - DeleteAction = rdr.GetValue("delete_action", string.Empty), - Keys = rdr.GetValue("constraint_keys", string.Empty), - Name = rdr.GetValue("constraint_name", string.Empty), - StatusEnabled = rdr.GetValue("status_enabled", string.Empty), - StatusForReplication = rdr.GetValue("status_for_replication", string.Empty), - Type = rdr.GetValue("constraint_type", string.Empty), - UpdateAction = rdr.GetValue("update_action", string.Empty) - }); - } - } + private readonly ObjectDescriptor _descriptor; + + public SpHelp() + { + _descriptor = new(); + } + + private ObjectDescriptor Result => _descriptor; + + protected override async Task OnExecute() + { + var operation = new StorageOperation { CommandText = "sp_help", CommandType = CommandType.Text }; + + operation.AddParameter(new StorageParameter + { + Name = "@objname", + Type = DbType.String, + Value = Escape(Context.Schema.SchemaName(), Context.Schema.Name) + } + ); + + var rdr = await Context.OpenReader(operation); + + try + { + ReadMetadata(rdr); + ReadColumns(rdr); + ReadIdentity(rdr); + ReadRowGuid(rdr); + ReadFileGroup(rdr); + ReadIndexes(rdr); + ReadConstraints(rdr); + } + finally + { + rdr.Close(); + } + + return Result; + } + + private void ReadMetadata(IDataReader rdr) + { + if (rdr.Read()) + { + Result.MetaData.Name = rdr.GetValue("Name", string.Empty); + Result.MetaData.Owner = rdr.GetValue("Owner", string.Empty); + Result.MetaData.Type = rdr.GetValue("Object_type", string.Empty); + } + } + + private void ReadColumns(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Columns.Add(new ObjectColumn + { + Collation = rdr.GetValue("Collation", string.Empty), + Computed = !string.Equals(rdr.GetValue("Computed", string.Empty), "no", StringComparison.OrdinalIgnoreCase), + FixedLenInSource = rdr.GetValue("FixedLenNullInSource", string.Empty), + Length = rdr.GetValue("Length", 0), + Name = rdr.GetValue("Column_name", string.Empty), + Nullable = !string.Equals(rdr.GetValue("Nullable", string.Empty), "no", StringComparison.OrdinalIgnoreCase), + Precision = TypeConversion.Convert(rdr.GetValue("Prec", string.Empty).Trim()), + Scale = TypeConversion.Convert(rdr.GetValue("Scale", string.Empty).Trim()), + TrimTrailingBlanks = rdr.GetValue("TrimTrailingBlanks", string.Empty), + Type = rdr.GetValue("Type", string.Empty) + }); + } + } + + private void ReadIdentity(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + { + Result.Identity.Identity = rdr.GetValue("Identity", string.Empty); + Result.Identity.Increment = rdr.GetValue("Increment", 0); + Result.Identity.NotForReplication = rdr.GetValue("Not For Replication", 0) != 0; + } + } + + private void ReadRowGuid(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + Result.RowGuid.RowGuidCol = rdr.GetValue("RowGuidCol", string.Empty); + } + + private void ReadFileGroup(IDataReader rdr) + { + rdr.NextResult(); + + if (rdr.Read()) + Result.FileGroup.FileGroup = rdr.GetValue("Data_located_on_filegroup", string.Empty); + } + + private void ReadIndexes(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Indexes.Add(new ObjectIndex + { + Description = rdr.GetValue("index_description", string.Empty), + Keys = rdr.GetValue("index_keys", string.Empty), + Name = rdr.GetValue("index_name", string.Empty) + }); + } + } + + private void ReadConstraints(IDataReader rdr) + { + rdr.NextResult(); + + while (rdr.Read()) + { + Result.Constraints.Add(new ObjectConstraint + { + DeleteAction = rdr.GetValue("delete_action", string.Empty), + Keys = rdr.GetValue("constraint_keys", string.Empty), + Name = rdr.GetValue("constraint_name", string.Empty), + StatusEnabled = rdr.GetValue("status_enabled", string.Empty), + StatusForReplication = rdr.GetValue("status_for_replication", string.Empty), + Type = rdr.GetValue("constraint_type", string.Empty), + UpdateAction = rdr.GetValue("update_action", string.Empty) + }); + } + } } diff --git a/Connected.Data/Storage/ConnectionProvider.cs b/Connected.Data/Storage/ConnectionProvider.cs index ffc50c4..5644146 100644 --- a/Connected.Data/Storage/ConnectionProvider.cs +++ b/Connected.Data/Storage/ConnectionProvider.cs @@ -1,10 +1,10 @@ -using Connected.Data.Schema; +using System.Collections.Immutable; +using Connected.Data.Schema; using Connected.Data.Sharding; using Connected.Entities.Storage; using Connected.Middleware; using Connected.ServiceModel; using Connected.ServiceModel.Transactions; -using System.Collections.Immutable; namespace Connected.Data.Storage; @@ -26,7 +26,7 @@ internal sealed class ConnectionProvider : IConnectionProvider, IAsyncDisposable public IMiddlewareService Middleware { get; } private ITransactionContext TransactionService { get; } private List Connections => _connections; - public StorageConnectionMode Mode { get; set; } = StorageConnectionMode.Isolated; + public StorageConnectionMode Mode { get; set; } = StorageConnectionMode.Shared; public async Task> Open(StorageContextArgs args) { /* diff --git a/Connected.Data/Update/CommandBuilder.cs b/Connected.Data/Update/CommandBuilder.cs index 1691613..0892801 100644 --- a/Connected.Data/Update/CommandBuilder.cs +++ b/Connected.Data/Update/CommandBuilder.cs @@ -1,216 +1,229 @@ -using Connected.Entities; +using System.Data; +using System.Reflection; +using System.Text; +using Connected.Entities; using Connected.Entities.Annotations; using Connected.Entities.Storage; using Connected.Interop; -using System.Data; -using System.Reflection; -using System.Text; namespace Connected.Data.Update; internal abstract class CommandBuilder { - private readonly List _parameters; - private readonly List _whereProperties; - private List _properties; + private readonly List _parameters; + private readonly List _whereProperties; + private List _properties; - protected CommandBuilder() - { - _parameters = new List(); - _whereProperties = new List(); + protected CommandBuilder() + { + _parameters = new List(); + _whereProperties = new List(); - Text = new StringBuilder(); - } + Text = new StringBuilder(); + } - public StorageOperation? Build(IEntity entity) - { - Entity = entity; + public StorageOperation? Build(IEntity entity) + { + Entity = entity; - if (TryGetExisting(out StorageOperation? existing)) - { - /* + if (TryGetExisting(out StorageOperation? existing)) + { + /* * We need to rebuild an instance since StorageOperation * is immutable */ - var result = new StorageOperation - { - CommandText = existing.CommandText, - CommandTimeout = existing.CommandTimeout, - CommandType = existing.CommandType, - Concurrency = existing.Concurrency - }; - - if (result.Parameters is null) - return result; - - foreach (var parameter in result.Parameters) - { - if (parameter.Direction == ParameterDirection.Input) - { - if (ResolveProperty(parameter.Name) is PropertyInfo property) - { - result.AddParameter(new StorageParameter - { - Value = GetValue(property), - Name = parameter.Name, - Type = parameter.Type, - Direction = parameter.Direction - }); - } - } - } - - return result; - } - - Schema = Entity.GetSchemaAttribute(); - - return OnBuild(); - } - - protected List Properties => _properties ??= GetProperties(); - - protected List WhereProperties => _whereProperties; - - protected string CommandText => Text.ToString(); - - protected IEntity Entity { get; private set; } - - protected SchemaAttribute Schema { get; private set; } - - private StringBuilder Text { get; set; } - - protected abstract StorageOperation OnBuild(); - - protected abstract bool TryGetExisting(out StorageOperation? result); - - protected List Parameters => _parameters; - - protected void Write(string text) - { - Text.Append(text); - } - - protected void Write(char text) - { - Text.Append(text); - } - - protected void WriteLine(string text) - { - Text.AppendLine(text); - } - - protected void Trim() - { - for (var i = Text.Length - 1; i >= 0; i--) - { - if (!Text[i].Equals(',') && !Text[i].Equals('\n') && !Text[i].Equals('\r') && !Text[i].Equals(' ')) - break; - - if (i < Text.Length) - Text.Length = i; - } - } - - protected virtual List GetProperties() - { - return Interop.Properties.GetImplementedProperties(Entity); - } - - protected static bool IsVersion(PropertyInfo property) - { - return property.GetCustomAttribute() is not null; - } - - protected static string ColumnName(PropertyInfo property) - { - var dataMember = property.FindAttribute(); - - return dataMember is null || string.IsNullOrEmpty(dataMember.Member) ? property.Name.ToCamelCase() : dataMember.Member; - } - - protected static DbType ResolveDbType(PropertyInfo property) - { - if (IsVersion(property)) - return DbType.Binary; - - return property.PropertyType.ToDbType(); - } - protected object? GetValue(PropertyInfo property) - { - if (IsNull(property)) - return "NULL"; - - if (IsVersion(property)) - return (byte[])EntityVersion.Parse(property.GetValue(Entity)); - - return GetValue(property.GetValue(Entity), property.PropertyType.ToDbType()); - } - - private static object? GetValue(object value, DbType dbType) - { - switch (dbType) - { - case DbType.Binary: - if (value is byte[] bytes) - return Convert.ToBase64String(bytes); - else - return Convert.ToBase64String(Encoding.UTF8.GetBytes(value.ToString())); - default: - return value; - } - } + var result = new StorageOperation + { + CommandText = existing.CommandText, + CommandTimeout = existing.CommandTimeout, + CommandType = existing.CommandType, + Concurrency = existing.Concurrency + }; + + if (result.Parameters is null) + return result; + + foreach (var parameter in result.Parameters) + { + if (parameter.Direction == ParameterDirection.Input) + { + if (ResolveProperty(parameter.Name) is PropertyInfo property) + { + result.AddParameter(new StorageParameter + { + Value = GetValue(property), + Name = parameter.Name, + Type = parameter.Type, + Direction = parameter.Direction + }); + } + } + } + + return result; + } + + Schema = Entity.GetSchemaAttribute(); + + return OnBuild(); + } + + protected List Properties => _properties ??= GetProperties(); + + protected List WhereProperties => _whereProperties; + + protected string CommandText => Text.ToString(); + + protected IEntity Entity { get; private set; } + + protected SchemaAttribute Schema { get; private set; } + + private StringBuilder Text { get; set; } + + protected abstract StorageOperation OnBuild(); + + protected abstract bool TryGetExisting(out StorageOperation? result); + + protected List Parameters => _parameters; + + protected void Write(string text) + { + Text.Append(text); + } + + protected void Write(char text) + { + Text.Append(text); + } + + protected void WriteLine(string text) + { + Text.AppendLine(text); + } + + protected void Trim() + { + for (var i = Text.Length - 1; i >= 0; i--) + { + if (!Text[i].Equals(',') && !Text[i].Equals('\n') && !Text[i].Equals('\r') && !Text[i].Equals(' ')) + break; + + if (i < Text.Length) + Text.Length = i; + } + } + + protected virtual List GetProperties() + { + var props = Interop.Properties.GetImplementedProperties(Entity); + var result = new List(); + + foreach (var property in props) + { + var persistence = property.FindAttribute(); + + if (persistence is not null && persistence.Persistence.HasFlag(ColumnPersistence.InMemory)) + continue; + + result.Add(property); + } + + return result; + } + + protected static bool IsVersion(PropertyInfo property) + { + return property.GetCustomAttribute() is not null; + } + + protected static string ColumnName(PropertyInfo property) + { + var dataMember = property.FindAttribute(); + + return dataMember is null || string.IsNullOrEmpty(dataMember.Member) ? property.Name.ToCamelCase() : dataMember.Member; + } + + protected static DbType ResolveDbType(PropertyInfo property) + { + if (IsVersion(property)) + return DbType.Binary; - private bool IsNull(PropertyInfo property) - { - var result = property.GetValue(Entity); + return property.PropertyType.ToDbType(); + } + protected object? GetValue(PropertyInfo property) + { + if (IsNull(property)) + return "NULL"; - if (result is null) - return true; + if (IsVersion(property)) + return (byte[])EntityVersion.Parse(property.GetValue(Entity)); - if (property.GetCustomAttribute() is null) - return false; + return GetValue(property.GetValue(Entity), property.PropertyType.ToDbType()); + } - var def = Types.GetDefault(property.PropertyType); - - return TypeComparer.Compare(result, def); - } - - protected StorageParameter CreateParameter(PropertyInfo property) - { - return CreateParameter(property, ParameterDirection.Input); - } - - protected StorageParameter CreateParameter(PropertyInfo property, ParameterDirection direction) - { - var columnName = ColumnName(property); - var parameterName = $"@{columnName}"; - - var parameter = new StorageParameter - { - Direction = direction, - Name = parameterName, - Type = ResolveDbType(property), - Value = GetValue(property) - }; - - Parameters.Add(parameter); - - return parameter; - } - - private PropertyInfo ResolveProperty(string parameterName) - { - var propertyName = parameterName[1..]; - var flags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic; - - if (Entity.GetType().GetProperty(propertyName.ToPascalCase(), flags) is PropertyInfo property) - return property; - - if (Entity.GetType().GetProperty(propertyName, flags) is PropertyInfo raw) - return raw; - - return null; - } + private static object? GetValue(object value, DbType dbType) + { + switch (dbType) + { + case DbType.Binary: + if (value is byte[] bytes) + return Convert.ToBase64String(bytes); + else + return Convert.ToBase64String(Encoding.UTF8.GetBytes(value.ToString())); + default: + return value; + } + } + + private bool IsNull(PropertyInfo property) + { + var result = property.GetValue(Entity); + + if (result is null) + return true; + + if (property.GetCustomAttribute() is null) + return false; + + var def = Types.GetDefault(property.PropertyType); + + return TypeComparer.Compare(result, def); + } + + protected StorageParameter CreateParameter(PropertyInfo property) + { + return CreateParameter(property, ParameterDirection.Input); + } + + protected StorageParameter CreateParameter(PropertyInfo property, ParameterDirection direction) + { + var columnName = ColumnName(property); + var parameterName = $"@{columnName}"; + + var parameter = new StorageParameter + { + Direction = direction, + Name = parameterName, + Type = ResolveDbType(property), + Value = GetValue(property) + }; + + Parameters.Add(parameter); + + return parameter; + } + + private PropertyInfo ResolveProperty(string parameterName) + { + var propertyName = parameterName[1..]; + var flags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic; + + if (Entity.GetType().GetProperty(propertyName.ToPascalCase(), flags) is PropertyInfo property) + return property; + + if (Entity.GetType().GetProperty(propertyName, flags) is PropertyInfo raw) + return raw; + + return null; + } } diff --git a/Connected.Expressions/Formatters/SqlFormatter.cs b/Connected.Expressions/Formatters/SqlFormatter.cs index 95b4d2c..11abea5 100644 --- a/Connected.Expressions/Formatters/SqlFormatter.cs +++ b/Connected.Expressions/Formatters/SqlFormatter.cs @@ -1,984 +1,990 @@ -using Connected.Expressions.Languages; -using Connected.Expressions.Translation; -using Connected.Expressions.Visitors; -using System.Collections.ObjectModel; +using System.Collections.ObjectModel; using System.Globalization; using System.Linq.Expressions; using System.Reflection; using System.Text; +using Connected.Expressions.Languages; +using Connected.Expressions.Translation; +using Connected.Expressions.Visitors; namespace Connected.Expressions.Formatters; public class SqlFormatter : DatabaseVisitor { - protected const char Space = ' '; - protected const char Period = '.'; - protected const char OpenBracket = '('; - protected const char CloseBracket = ')'; - protected const char SingleQuote = '\''; - protected enum Indentation - { - Same, - Inner, - Outer - } - - protected SqlFormatter(QueryLanguage? language) - { - Language = language; - Text = new StringBuilder(); - Aliases = new(); - } - - private int Depth { get; set; } - protected virtual QueryLanguage? Language { get; } - protected bool HideColumnAliases { get; set; } - protected bool HideTableAliases { get; set; } - protected bool IsNested { get; set; } - public int IndentationWidth { get; set; } = 2; - private StringBuilder Text { get; } - private Dictionary Aliases { get; } - - public static string Format(Expression expression) - { - var formatter = new SqlFormatter(null); - - formatter.Visit(expression); - - return formatter.ToString(); - } - - public override string ToString() - { - return Text.ToString(); - } - - protected void Write(object value) - { - Text.Append(value); - } - - protected virtual void WriteParameterName(string name) - { - Write($"@{name}"); - } - - protected virtual void WriteVariableName(string name) - { - WriteParameterName(name); - } - - protected virtual void WriteAsAliasName(string aliasName) - { - Write("AS "); - WriteAliasName(aliasName); - } - - protected virtual void WriteAliasName(string aliasName) - { - Write(aliasName); - } - - protected virtual void WriteAsColumnName(string columnName) - { - Write("AS "); - WriteColumnName(columnName); - } - - protected virtual void WriteColumnName(string columnName) - { - var name = Language is not null ? Language.Quote(columnName) : columnName; - - Write(name); - } - - protected virtual void WriteTableName(string tableSchema, string tableName) - { - var name = Language is not null ? Language.Quote(tableName) : tableName; - var schema = Language is not null ? Language.Quote(tableSchema) : tableName; - - Write($"{schema}.{name}"); - } - - protected void WriteLine(Indentation style) - { - Text.AppendLine(); - Indent(style); - - for (var i = 0; i < Depth * IndentationWidth; i++) - Write(Space); - } - - protected void Indent(Indentation style) - { - if (style == Indentation.Inner) - Depth++; - else if (style == Indentation.Outer) - Depth--; - } - - protected virtual string GetAliasName(Alias alias) - { - if (!Aliases.TryGetValue(alias, out string? name)) - { - name = $"A{alias.GetHashCode()}?"; - - Aliases.Add(alias, name); - } - - return name; - } - - protected void AddAlias(Alias alias) - { - if (!Aliases.TryGetValue(alias, out _)) - { - var name = $"t{Aliases.Count}"; - - Aliases.Add(alias, name); - } - } - - protected virtual void AddAliases(Expression expr) - { - if (expr as AliasedExpression is AliasedExpression ax) - AddAlias(ax.Alias); - else - { - if (expr as JoinExpression is JoinExpression jx) - { - AddAliases(jx.Left); - AddAliases(jx.Right); - } - } - } - - protected override Expression? Visit(Expression? exp) - { - if (exp is null) - return null; - - return exp.NodeType switch - { - ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Not or ExpressionType.Convert or ExpressionType.ConvertChecked - or ExpressionType.UnaryPlus or ExpressionType.Add or ExpressionType.AddChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked - or ExpressionType.Multiply or ExpressionType.MultiplyChecked or ExpressionType.Divide or ExpressionType.Modulo or ExpressionType.And - or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse or ExpressionType.LessThan or ExpressionType.LessThanOrEqual - or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.Equal or ExpressionType.NotEqual - or ExpressionType.Coalesce or ExpressionType.RightShift or ExpressionType.LeftShift or ExpressionType.ExclusiveOr - or ExpressionType.Power or ExpressionType.Conditional or ExpressionType.Constant or ExpressionType.MemberAccess or ExpressionType.Call - or ExpressionType.New or (ExpressionType)DatabaseExpressionType.Table or (ExpressionType)DatabaseExpressionType.Column - or (ExpressionType)DatabaseExpressionType.Select or (ExpressionType)DatabaseExpressionType.Join or (ExpressionType)DatabaseExpressionType.Aggregate - or (ExpressionType)DatabaseExpressionType.Scalar or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In - or (ExpressionType)DatabaseExpressionType.AggregateSubquery or (ExpressionType)DatabaseExpressionType.IsNull - or (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.RowCount - or (ExpressionType)DatabaseExpressionType.Projection or (ExpressionType)DatabaseExpressionType.NamedValue - or (ExpressionType)DatabaseExpressionType.Block or (ExpressionType)DatabaseExpressionType.If or (ExpressionType)DatabaseExpressionType.Declaration - or (ExpressionType)DatabaseExpressionType.Variable or (ExpressionType)DatabaseExpressionType.Function => base.Visit(exp), - _ => throw new NotSupportedException($"The expression node of type '{DatabaseExpressionExtensions.ResolveNodeTypeName(exp)}' is not supported."), - }; - } - - protected override Expression VisitMemberAccess(MemberExpression m) - { - throw new NotSupportedException($"The member access '{m.Member}' is not supported."); - } - - protected override Expression? VisitMethodCall(MethodCallExpression m) - { - if (m.Method.DeclaringType == typeof(decimal)) - { - switch (m.Method.Name) - { - case "Add": - case "Subtract": - case "Multiply": - case "Divide": - case "Remainder": - Write(OpenBracket); - VisitValue(m.Arguments[0]); - Write(Space); - Write(GetOperator(m.Method.Name)); - Write(Space); - VisitValue(m.Arguments[1]); - Write(CloseBracket); - return m; - case "Negate": - Write('-'); - Visit(m.Arguments[0]); - Write(string.Empty); - return m; - case "Compare": - Visit(Expression.Condition( - Expression.Equal(m.Arguments[0], m.Arguments[1]), - Expression.Constant(0), - Expression.Condition( - Expression.LessThan(m.Arguments[0], m.Arguments[1]), - Expression.Constant(-1), - Expression.Constant(1) - ))); - - return m; - } - } - else if (string.Equals(m.Method.Name, "ToString", StringComparison.Ordinal) && m.Object?.Type == typeof(string)) - { - /* - * no op. - */ - return Visit(m.Object); - } - else if (string.Equals(m.Method.Name, "Equals", StringComparison.Ordinal)) - { - if (m.Method.IsStatic && m.Method.DeclaringType == typeof(object) || m.Method.DeclaringType == typeof(string)) - { - Write(OpenBracket); - Visit(m.Arguments[0]); - Write(" = "); - Visit(m.Arguments[1]); - Write(CloseBracket); - return m; - } - else if (!m.Method.IsStatic && m.Arguments.Count == 1 && m.Arguments[0].Type == m.Object?.Type) - { - Write(OpenBracket); - Visit(m.Object); - Write(" = "); - Visit(m.Arguments[0]); - Write(CloseBracket); - return m; - } - } - - throw new NotSupportedException($"The method '{m.Method.Name}' is not supported"); - } - - protected virtual bool IsInteger(Type type) - { - return Interop.TypeSystem.IsInteger(type); - } - - protected override NewExpression VisitNew(NewExpression nex) - { - throw new NotSupportedException($"The constructor for '{nex.Constructor?.DeclaringType}' is not supported"); - } - - protected override Expression VisitUnary(UnaryExpression u) - { - var op = GetOperator(u); - - switch (u.NodeType) - { - case ExpressionType.Not: - - if (u.Operand is IsNullExpression nullExpression) - { - Visit(nullExpression.Expression); - Write(" IS NOT NULL"); - } - else if (IsBoolean(u.Operand.Type) || op.Length > 1) - { - Write(op); - Write(Space); - VisitPredicate(u.Operand); - } - else - { - Write(op); - VisitValue(u.Operand); - } - - break; - case ExpressionType.Negate: - case ExpressionType.NegateChecked: - Write(op); - VisitValue(u.Operand); - break; - case ExpressionType.UnaryPlus: - VisitValue(u.Operand); - break; - case ExpressionType.Convert: - /* - * ignore conversions for now - */ - Visit(u.Operand); - break; - default: - throw new NotSupportedException($"The unary operator '{u.NodeType}' is not supported"); - } - - return u; - } - - protected override Expression VisitBinary(BinaryExpression b) - { - var op = GetOperator(b); - var left = b.Left; - var right = b.Right; - - Write(OpenBracket); - - switch (b.NodeType) - { - case ExpressionType.And: - case ExpressionType.AndAlso: - case ExpressionType.Or: - case ExpressionType.OrElse: - if (IsBoolean(left.Type)) - { - VisitPredicate(left); - Write(Space); - Write(op); - Write(Space); - VisitPredicate(right); - } - else - { - VisitValue(left); - Write(Space); - Write(op); - Write(Space); - VisitValue(right); - } - break; - case ExpressionType.Equal: - if (right.NodeType == ExpressionType.Constant) - { - var ce = (ConstantExpression)right; - - if (ce.Value is null) - { - Visit(left); - Write(" IS NULL"); - - break; - } - } - else if (left.NodeType == ExpressionType.Constant) - { - var ce = (ConstantExpression)left; - - if (ce.Value is null) - { - Visit(right); - Write(" IS NULL"); - - break; - } - } - goto case ExpressionType.LessThan; - case ExpressionType.NotEqual: - if (right.NodeType == ExpressionType.Constant) - { - var ce = (ConstantExpression)right; - - if (ce.Value is null) - { - Visit(left); - Write(" IS NOT NULL"); - - break; - } - } - else if (left.NodeType == ExpressionType.Constant) - { - var ce = (ConstantExpression)left; - - if (ce.Value is null) - { - Visit(right); - Write(" IS NOT NULL"); - - break; - } - } - goto case ExpressionType.LessThan; - case ExpressionType.LessThan: - case ExpressionType.LessThanOrEqual: - case ExpressionType.GreaterThan: - case ExpressionType.GreaterThanOrEqual: - /* - * check for special x.CompareTo(y) && type.Compare(x,y) - */ - if (left.NodeType == ExpressionType.Call && right.NodeType == ExpressionType.Constant) - { - var mc = (MethodCallExpression)left; - var ce = (ConstantExpression)right; - - if (ce.Value is not null && ce.Value.GetType() == typeof(int) && (int)ce.Value == 0) - { - if (string.Equals(mc.Method.Name, "CompareTo", StringComparison.Ordinal) && !mc.Method.IsStatic && mc.Arguments.Count == 1) - { - left = mc.Object; - right = mc.Arguments[0]; - } - else if ((mc.Method.DeclaringType == typeof(string) || mc.Method.DeclaringType == typeof(decimal)) - && string.Equals(mc.Method.Name, "Compare", StringComparison.Ordinal) && mc.Method.IsStatic && mc.Arguments.Count == 2) - { - left = mc.Arguments[0]; - right = mc.Arguments[1]; - } - } - } - goto case ExpressionType.Add; - case ExpressionType.Add: - case ExpressionType.AddChecked: - case ExpressionType.Subtract: - case ExpressionType.SubtractChecked: - case ExpressionType.Multiply: - case ExpressionType.MultiplyChecked: - case ExpressionType.Divide: - case ExpressionType.Modulo: - case ExpressionType.ExclusiveOr: - case ExpressionType.LeftShift: - case ExpressionType.RightShift: - - if (left is not null) - VisitValue(left); - - Write(Space); - Write(op); - Write(Space); - VisitValue(right); - break; - default: - throw new NotSupportedException($"The binary operator '{b.NodeType}' is not supported"); - } - - Write(CloseBracket); - - return b; - } - - protected virtual string GetOperator(string methodName) - { - return methodName switch - { - "Add" => "+", - "Subtract" => "-", - "Multiply" => "*", - "Divide" => "/", - "Negate" => "-", - "Remainder" => "%", - _ => string.Empty, - }; - } - - protected virtual string GetOperator(UnaryExpression u) - { - return u.NodeType switch - { - ExpressionType.Negate or ExpressionType.NegateChecked => "-", - ExpressionType.UnaryPlus => "+", - ExpressionType.Not => IsBoolean(u.Operand.Type) ? "NOT" : "~", - _ => string.Empty, - }; - } - - protected virtual string GetOperator(BinaryExpression b) - { - return b.NodeType switch - { - ExpressionType.And or ExpressionType.AndAlso => IsBoolean(b.Left.Type) ? "AND" : "&", - ExpressionType.Or or ExpressionType.OrElse => IsBoolean(b.Left.Type) ? "OR" : "|", - ExpressionType.Equal => "=", - ExpressionType.NotEqual => "<>", - ExpressionType.LessThan => "<", - ExpressionType.LessThanOrEqual => "<=", - ExpressionType.GreaterThan => ">", - ExpressionType.GreaterThanOrEqual => ">=", - ExpressionType.Add or ExpressionType.AddChecked => "+", - ExpressionType.Subtract or ExpressionType.SubtractChecked => "-", - ExpressionType.Multiply or ExpressionType.MultiplyChecked => "*", - ExpressionType.Divide => "/", - ExpressionType.Modulo => "%", - ExpressionType.ExclusiveOr => "^", - ExpressionType.LeftShift => "<<", - ExpressionType.RightShift => ">>", - _ => string.Empty, - }; - } - - protected virtual bool IsBoolean(Type type) - { - return type == typeof(bool) || type == typeof(bool?); - } - - protected virtual bool IsPredicate(Expression expr) - { - return expr.NodeType switch - { - ExpressionType.And or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse => IsBoolean(((BinaryExpression)expr).Type), - ExpressionType.Not => IsBoolean(((UnaryExpression)expr).Type), - ExpressionType.Equal or ExpressionType.NotEqual or ExpressionType.LessThan or ExpressionType.LessThanOrEqual or - ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or (ExpressionType)DatabaseExpressionType.IsNull or - (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In => true, - ExpressionType.Call => IsBoolean(((MethodCallExpression)expr).Type), - _ => false, - }; - } - - protected virtual Expression VisitPredicate(Expression expr) - { - Visit(expr); - - if (!IsPredicate(expr)) - Write(" <> 0"); - - return expr; - } - - protected virtual Expression? VisitValue(Expression expr) - { - return Visit(expr); - } - - protected override Expression VisitConditional(ConditionalExpression c) - { - throw new NotSupportedException("Conditional expressions not supported."); - } - - protected override Expression VisitConstant(ConstantExpression c) - { - WriteValue(c.Value); - - return c; - } - - protected virtual void WriteValue(object? value) - { - if (value is null) - Write("NULL"); - else if (value.GetType().GetTypeInfo().IsEnum) - Write(Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()))); - else - { - switch (Interop.TypeSystem.GetTypeCode(value.GetType())) - { - case TypeCode.Boolean: - Write((bool)value ? 1 : 0); - break; - case TypeCode.String: - Write(SingleQuote); - Write(value); - Write(SingleQuote); - break; - case TypeCode.Object: - throw new NotSupportedException($"The constant for '{value}' is not supported"); - case TypeCode.Single: - case TypeCode.Double: - var str = ((IConvertible)value).ToString(NumberFormatInfo.InvariantInfo); - - if (!str.Contains(Period)) - str = string.Concat(str, $"{Period}0"); - - Write(str); - break; - default: - Write((value as IConvertible)?.ToString(CultureInfo.InvariantCulture) ?? value); - break; - } - } - } - protected override Expression VisitColumn(ColumnExpression column) - { - if (column.Alias is not null && !HideColumnAliases) - { - WriteAliasName(GetAliasName(column.Alias)); - Write(Period); - } - - WriteColumnName(column.Name); - - return column; - } - protected override Expression VisitProjection(ProjectionExpression proj) - { - // treat these like scalar subqueries - if (proj.Projector is ColumnExpression) - { - Write(OpenBracket); - WriteLine(Indentation.Inner); - Visit(proj.Select); - Write(CloseBracket); - Indent(Indentation.Outer); - } - else - throw new NotSupportedException("Non-scalar projections cannot be translated to SQL."); - - return proj; - } - - protected override Expression VisitSelect(SelectExpression select) - { - AddAliases(select.From); - Write("SELECT "); - - if (select.IsDistinct) - Write("DISTINCT "); - - if (select.Take is not null) - WriteTopClause(select.Take); - - WriteColumns(select.Columns); - - if (select.From is not null) - { - WriteLine(Indentation.Same); - Write("FROM "); - VisitSource(select.From); - } - - if (select.Where is not null) - { - WriteLine(Indentation.Same); - Write("WHERE "); - VisitPredicate(select.Where); - } - - if (select.GroupBy is not null && select.GroupBy.Any()) - { - WriteLine(Indentation.Same); - Write("GROUP BY "); - - for (var i = 0; i < select.GroupBy.Count; i++) - { - if (i > 0) - Write(", "); - - VisitValue(select.GroupBy[i]); - } - } - - if (select.OrderBy is not null && select.OrderBy.Any()) - { - WriteLine(Indentation.Same); - Write("ORDER BY "); - - for (var i = 0; i < select.OrderBy.Count; i++) - { - var exp = select.OrderBy[i]; - - if (i > 0) - Write(", "); - - VisitValue(exp.Expression); - - if (exp.OrderType != OrderType.Ascending) - Write(" DESC"); - } - } - - return select; - } - - protected virtual void WriteTopClause(Expression expression) - { - Write("TOP ("); - Visit(expression); - Write(") "); - } - - protected virtual void WriteColumns(ReadOnlyCollection columns) - { - if (columns.Any()) - { - for (var i = 0; i < columns.Count; i++) - { - var column = columns[i]; - - if (i > 0) - Write(", "); - - var c = VisitValue(column.Expression) as ColumnExpression; - - if (!string.IsNullOrEmpty(column.Name) && (c is null || !string.Equals(c.Name, column.Name, StringComparison.OrdinalIgnoreCase))) - { - Write(Space); - WriteAsColumnName(column.Name); - } - } - } - else - { - Write("NULL "); - - if (IsNested) - { - WriteAsColumnName("tmp"); - Write(Space); - } - } - } - - protected override Expression VisitSource(Expression source) - { - var saveIsNested = IsNested; - - IsNested = true; - - switch ((DatabaseExpressionType)source.NodeType) - { - case DatabaseExpressionType.Table: - var table = (TableExpression)source; - - WriteTableName(table.Schema, table.Name); - - if (!HideTableAliases) - { - Write(Space); - WriteAsAliasName(GetAliasName(table.Alias)); - } - break; - case DatabaseExpressionType.Select: - var select = (SelectExpression)source; - - Write(OpenBracket); - WriteLine(Indentation.Inner); - Visit(select); - WriteLine(Indentation.Same); - Write($"{CloseBracket} "); - WriteAsAliasName(GetAliasName(select.Alias)); - Indent(Indentation.Outer); - break; - case DatabaseExpressionType.Join: - VisitJoin((JoinExpression)source); - break; - default: - throw new InvalidOperationException("Select source is not valid type"); - } - - IsNested = saveIsNested; - - return source; - } - - protected override Expression VisitJoin(JoinExpression join) - { - VisitJoinLeft(join.Left); - WriteLine(Indentation.Same); - - switch (join.Join) - { - case JoinType.CrossJoin: - Write("CROSS JOIN "); - break; - case JoinType.InnerJoin: - Write("INNER JOIN "); - break; - case JoinType.CrossApply: - Write("CROSS APPLY "); - break; - case JoinType.OuterApply: - Write("OUTER APPLY "); - break; - case JoinType.LeftOuter: - case JoinType.SingletonLeftOuter: - Write("LEFT OUTER JOIN "); - break; - } - - VisitJoinRight(join.Right); - - if (join.Condition is not null) - { - WriteLine(Indentation.Inner); - Write("ON "); - VisitPredicate(join.Condition); - Indent(Indentation.Outer); - } - - return join; - } - - protected virtual Expression VisitJoinLeft(Expression source) - { - return VisitSource(source); - } - - protected virtual Expression VisitJoinRight(Expression source) - { - return VisitSource(source); - } - - protected virtual void WriteAggregateName(string aggregateName) - { - switch (aggregateName) - { - case "Average": - Write("AVG"); - break; - case "LongCount": - Write("COUNT"); - break; - default: - Write(aggregateName.ToUpper()); - break; - } - } - - protected virtual bool RequiresAsteriskWhenNoArgument(string aggregateName) - { - return string.Equals(aggregateName, "Count", StringComparison.Ordinal) || string.Equals(aggregateName, "LongCount", StringComparison.Ordinal); - } - - protected override Expression VisitAggregate(AggregateExpression aggregate) - { - WriteAggregateName(aggregate.AggregateName); - Write(OpenBracket); - - if (aggregate.IsDistinct) - Write("DISTINCT "); - - if (aggregate.Argument is not null) - VisitValue(aggregate.Argument); - else if (RequiresAsteriskWhenNoArgument(aggregate.AggregateName)) - Write("*"); - - Write(CloseBracket); - - return aggregate; - } - - protected override Expression VisitIsNull(IsNullExpression isnull) - { - VisitValue(isnull.Expression); - Write(" IS NULL"); - - return isnull; - } - - protected override Expression VisitBetween(BetweenExpression between) - { - VisitValue(between.Expression); - Write(" BETWEEN "); - VisitValue(between.Lower); - Write(" AND "); - VisitValue(between.Upper); - - return between; - } - - protected override Expression VisitRowNumber(RowNumberExpression rowNumber) - { - throw new NotSupportedException(); - } - - protected override Expression VisitScalar(ScalarExpression subquery) - { - Write(OpenBracket); - WriteLine(Indentation.Inner); - Visit(subquery.Select); - WriteLine(Indentation.Same); - Write(CloseBracket); - Indent(Indentation.Outer); - - return subquery; - } - - protected override Expression VisitExists(ExistsExpression exists) - { - Write($"EXISTS{OpenBracket}"); - WriteLine(Indentation.Inner); - Visit(exists.Select); - WriteLine(Indentation.Same); - Write(CloseBracket); - Indent(Indentation.Outer); - - return exists; - } - protected override Expression VisitIn(InExpression @in) - { - if (@in.Values is not null) - { - if (!@in.Values.Any()) - Write("0 <> 0"); - else - { - VisitValue(@in.Expression); - Write($" IN {OpenBracket}"); - - for (var i = 0; i < @in.Values.Count; i++) - { - if (i > 0) - Write(", "); - - VisitValue(@in.Values[i]); - } - - Write(CloseBracket); - } - } - else - { - VisitValue(@in.Expression); - Write($" IN {OpenBracket}"); - WriteLine(Indentation.Inner); - Visit(@in.Select); - WriteLine(Indentation.Same); - Write(CloseBracket); - Indent(Indentation.Outer); - } - - return @in; - } - - protected override Expression VisitNamedValue(NamedValueExpression value) - { - WriteParameterName(value.Name); - - return value; - } - - protected override Expression VisitIf(IfCommandExpression ifx) - { - throw new NotSupportedException(); - } - - protected override Expression VisitBlock(BlockExpression block) - { - throw new NotSupportedException(); - } - - protected override Expression VisitDeclaration(DeclarationExpression declaration) - { - throw new NotSupportedException(nameof(declaration)); - } - - protected override Expression VisitVariable(VariableExpression vex) - { - WriteVariableName(vex.Name); - - return vex; - } - - protected virtual void VisitStatement(Expression expression) - { - if (expression is ProjectionExpression p) - Visit(p.Select); - else - Visit(expression); - } - - protected override Expression VisitFunction(FunctionExpression func) - { - Write(func.Name); - - if (func.Arguments is not null && func.Arguments.Any()) - { - Write(OpenBracket); - - for (var i = 0; i < func.Arguments.Count; i++) - { - if (i > 0) - Write(", "); - - Visit(func.Arguments[i]); - } - - Write(CloseBracket); - } - - return func; - } + protected const char Space = ' '; + protected const char Period = '.'; + protected const char OpenBracket = '('; + protected const char CloseBracket = ')'; + protected const char SingleQuote = '\''; + protected enum Indentation + { + Same, + Inner, + Outer + } + + protected SqlFormatter(QueryLanguage? language) + { + Language = language; + Text = new StringBuilder(); + Aliases = new(); + } + + private int Depth { get; set; } + protected virtual QueryLanguage? Language { get; } + protected bool HideColumnAliases { get; set; } + protected bool HideTableAliases { get; set; } + protected bool UseBracketsInWhere { get; set; } = true; + protected bool IsNested { get; set; } + public int IndentationWidth { get; set; } = 2; + private StringBuilder Text { get; } + private Dictionary Aliases { get; } + + public static string Format(Expression expression) + { + var formatter = new SqlFormatter(null); + + formatter.Visit(expression); + + return formatter.ToString(); + } + + public override string ToString() + { + return Text.ToString(); + } + + protected void Write(object value) + { + Text.Append(value); + } + + protected virtual void WriteParameterName(string name) + { + Write($"@{name}"); + } + + protected virtual void WriteVariableName(string name) + { + WriteParameterName(name); + } + + protected virtual void WriteAsAliasName(string aliasName) + { + Write("AS "); + WriteAliasName(aliasName); + } + + protected virtual void WriteAliasName(string aliasName) + { + Write(aliasName); + } + + protected virtual void WriteAsColumnName(string columnName) + { + Write("AS "); + WriteColumnName(columnName); + } + + protected virtual void WriteColumnName(string columnName) + { + var name = Language is not null ? Language.Quote(columnName) : columnName; + + Write(name); + } + + protected virtual void WriteTableName(string tableSchema, string tableName) + { + var name = Language is not null ? Language.Quote(tableName) : tableName; + var schema = Language is not null ? Language.Quote(tableSchema) : tableName; + + Write($"{schema}.{name}"); + } + + protected void WriteLine(Indentation style) + { + Text.AppendLine(); + Indent(style); + + for (var i = 0; i < Depth * IndentationWidth; i++) + Write(Space); + } + + protected void Indent(Indentation style) + { + if (style == Indentation.Inner) + Depth++; + else if (style == Indentation.Outer) + Depth--; + } + + protected virtual string GetAliasName(Alias alias) + { + if (!Aliases.TryGetValue(alias, out string? name)) + { + name = $"A{alias.GetHashCode()}?"; + + Aliases.Add(alias, name); + } + + return name; + } + + protected void AddAlias(Alias alias) + { + if (!Aliases.TryGetValue(alias, out _)) + { + var name = $"t{Aliases.Count}"; + + Aliases.Add(alias, name); + } + } + + protected virtual void AddAliases(Expression expr) + { + if (expr as AliasedExpression is AliasedExpression ax) + AddAlias(ax.Alias); + else + { + if (expr as JoinExpression is JoinExpression jx) + { + AddAliases(jx.Left); + AddAliases(jx.Right); + } + } + } + + protected override Expression? Visit(Expression? exp) + { + if (exp is null) + return null; + + return exp.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked or ExpressionType.Not or ExpressionType.Convert or ExpressionType.ConvertChecked + or ExpressionType.UnaryPlus or ExpressionType.Add or ExpressionType.AddChecked or ExpressionType.Subtract or ExpressionType.SubtractChecked + or ExpressionType.Multiply or ExpressionType.MultiplyChecked or ExpressionType.Divide or ExpressionType.Modulo or ExpressionType.And + or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.Coalesce or ExpressionType.RightShift or ExpressionType.LeftShift or ExpressionType.ExclusiveOr + or ExpressionType.Power or ExpressionType.Conditional or ExpressionType.Constant or ExpressionType.MemberAccess or ExpressionType.Call + or ExpressionType.New or (ExpressionType)DatabaseExpressionType.Table or (ExpressionType)DatabaseExpressionType.Column + or (ExpressionType)DatabaseExpressionType.Select or (ExpressionType)DatabaseExpressionType.Join or (ExpressionType)DatabaseExpressionType.Aggregate + or (ExpressionType)DatabaseExpressionType.Scalar or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In + or (ExpressionType)DatabaseExpressionType.AggregateSubquery or (ExpressionType)DatabaseExpressionType.IsNull + or (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.RowCount + or (ExpressionType)DatabaseExpressionType.Projection or (ExpressionType)DatabaseExpressionType.NamedValue + or (ExpressionType)DatabaseExpressionType.Block or (ExpressionType)DatabaseExpressionType.If or (ExpressionType)DatabaseExpressionType.Declaration + or (ExpressionType)DatabaseExpressionType.Variable or (ExpressionType)DatabaseExpressionType.Function => base.Visit(exp), + _ => throw new NotSupportedException($"The expression node of type '{DatabaseExpressionExtensions.ResolveNodeTypeName(exp)}' is not supported."), + }; + } + + protected override Expression VisitMemberAccess(MemberExpression m) + { + throw new NotSupportedException($"The member access '{m.Member}' is not supported."); + } + + protected override Expression? VisitMethodCall(MethodCallExpression m) + { + if (m.Method.DeclaringType == typeof(decimal)) + { + switch (m.Method.Name) + { + case "Add": + case "Subtract": + case "Multiply": + case "Divide": + case "Remainder": + Write(OpenBracket); + VisitValue(m.Arguments[0]); + Write(Space); + Write(GetOperator(m.Method.Name)); + Write(Space); + VisitValue(m.Arguments[1]); + Write(CloseBracket); + return m; + case "Negate": + Write('-'); + Visit(m.Arguments[0]); + Write(string.Empty); + return m; + case "Compare": + Visit(Expression.Condition( + Expression.Equal(m.Arguments[0], m.Arguments[1]), + Expression.Constant(0), + Expression.Condition( + Expression.LessThan(m.Arguments[0], m.Arguments[1]), + Expression.Constant(-1), + Expression.Constant(1) + ))); + + return m; + } + } + else if (string.Equals(m.Method.Name, "ToString", StringComparison.Ordinal) && m.Object?.Type == typeof(string)) + { + /* + * no op. + */ + return Visit(m.Object); + } + else if (string.Equals(m.Method.Name, "Equals", StringComparison.Ordinal)) + { + if (m.Method.IsStatic && m.Method.DeclaringType == typeof(object) || m.Method.DeclaringType == typeof(string)) + { + Write(OpenBracket); + Visit(m.Arguments[0]); + Write(" = "); + Visit(m.Arguments[1]); + Write(CloseBracket); + return m; + } + else if (!m.Method.IsStatic && m.Arguments.Count == 1 && m.Arguments[0].Type == m.Object?.Type) + { + Write(OpenBracket); + Visit(m.Object); + Write(" = "); + Visit(m.Arguments[0]); + Write(CloseBracket); + return m; + } + } + + throw new NotSupportedException($"The method '{m.Method.Name}' is not supported"); + } + + protected virtual bool IsInteger(Type type) + { + return Interop.TypeSystem.IsInteger(type); + } + + protected override NewExpression VisitNew(NewExpression nex) + { + throw new NotSupportedException($"The constructor for '{nex.Constructor?.DeclaringType}' is not supported"); + } + + protected override Expression VisitUnary(UnaryExpression u) + { + var op = GetOperator(u); + + switch (u.NodeType) + { + case ExpressionType.Not: + + if (u.Operand is IsNullExpression nullExpression) + { + Visit(nullExpression.Expression); + Write(" IS NOT NULL"); + } + else if (IsBoolean(u.Operand.Type) || op.Length > 1) + { + Write(op); + Write(Space); + VisitPredicate(u.Operand); + } + else + { + Write(op); + VisitValue(u.Operand); + } + + break; + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + Write(op); + VisitValue(u.Operand); + break; + case ExpressionType.UnaryPlus: + VisitValue(u.Operand); + break; + case ExpressionType.Convert: + /* + * ignore conversions for now + */ + Visit(u.Operand); + break; + default: + throw new NotSupportedException($"The unary operator '{u.NodeType}' is not supported"); + } + + return u; + } + + protected override Expression VisitBinary(BinaryExpression b) + { + var op = GetOperator(b); + var left = b.Left; + var right = b.Right; + + if (UseBracketsInWhere) + Write(OpenBracket); + + switch (b.NodeType) + { + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + if (IsBoolean(left.Type)) + { + VisitPredicate(left); + Write(Space); + Write(op); + Write(Space); + VisitPredicate(right); + } + else + { + VisitValue(left); + Write(Space); + Write(op); + Write(Space); + VisitValue(right); + } + break; + case ExpressionType.Equal: + if (right.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)right; + + if (ce.Value is null) + { + Visit(left); + Write(" IS NULL"); + + break; + } + } + else if (left.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)left; + + if (ce.Value is null) + { + Visit(right); + Write(" IS NULL"); + + break; + } + } + goto case ExpressionType.LessThan; + case ExpressionType.NotEqual: + if (right.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)right; + + if (ce.Value is null) + { + Visit(left); + Write(" IS NOT NULL"); + + break; + } + } + else if (left.NodeType == ExpressionType.Constant) + { + var ce = (ConstantExpression)left; + + if (ce.Value is null) + { + Visit(right); + Write(" IS NOT NULL"); + + break; + } + } + goto case ExpressionType.LessThan; + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + /* + * check for special x.CompareTo(y) && type.Compare(x,y) + */ + if (left.NodeType == ExpressionType.Call && right.NodeType == ExpressionType.Constant) + { + var mc = (MethodCallExpression)left; + var ce = (ConstantExpression)right; + + if (ce.Value is not null && ce.Value.GetType() == typeof(int) && (int)ce.Value == 0) + { + if (string.Equals(mc.Method.Name, "CompareTo", StringComparison.Ordinal) && !mc.Method.IsStatic && mc.Arguments.Count == 1) + { + left = mc.Object; + right = mc.Arguments[0]; + } + else if ((mc.Method.DeclaringType == typeof(string) || mc.Method.DeclaringType == typeof(decimal)) + && string.Equals(mc.Method.Name, "Compare", StringComparison.Ordinal) && mc.Method.IsStatic && mc.Arguments.Count == 2) + { + left = mc.Arguments[0]; + right = mc.Arguments[1]; + } + } + } + goto case ExpressionType.Add; + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + + if (left is not null) + VisitValue(left); + + Write(Space); + Write(op); + Write(Space); + VisitValue(right); + break; + default: + throw new NotSupportedException($"The binary operator '{b.NodeType}' is not supported"); + } + + if (UseBracketsInWhere) + Write(CloseBracket); + + return b; + } + + protected virtual string GetOperator(string methodName) + { + return methodName switch + { + "Add" => "+", + "Subtract" => "-", + "Multiply" => "*", + "Divide" => "/", + "Negate" => "-", + "Remainder" => "%", + _ => string.Empty, + }; + } + + protected virtual string GetOperator(UnaryExpression u) + { + return u.NodeType switch + { + ExpressionType.Negate or ExpressionType.NegateChecked => "-", + ExpressionType.UnaryPlus => "+", + ExpressionType.Not => IsBoolean(u.Operand.Type) ? "NOT" : "~", + _ => string.Empty, + }; + } + + protected virtual string GetOperator(BinaryExpression b) + { + return b.NodeType switch + { + ExpressionType.And or ExpressionType.AndAlso => IsBoolean(b.Left.Type) ? "AND" : "&", + ExpressionType.Or or ExpressionType.OrElse => IsBoolean(b.Left.Type) ? "OR" : "|", + ExpressionType.Equal => "=", + ExpressionType.NotEqual => "<>", + ExpressionType.LessThan => "<", + ExpressionType.LessThanOrEqual => "<=", + ExpressionType.GreaterThan => ">", + ExpressionType.GreaterThanOrEqual => ">=", + ExpressionType.Add or ExpressionType.AddChecked => "+", + ExpressionType.Subtract or ExpressionType.SubtractChecked => "-", + ExpressionType.Multiply or ExpressionType.MultiplyChecked => "*", + ExpressionType.Divide => "/", + ExpressionType.Modulo => "%", + ExpressionType.ExclusiveOr => "^", + ExpressionType.LeftShift => "<<", + ExpressionType.RightShift => ">>", + _ => string.Empty, + }; + } + + protected virtual bool IsBoolean(Type type) + { + return type == typeof(bool) || type == typeof(bool?); + } + + protected virtual bool IsPredicate(Expression expr) + { + return expr.NodeType switch + { + ExpressionType.And or ExpressionType.AndAlso or ExpressionType.Or or ExpressionType.OrElse => IsBoolean(((BinaryExpression)expr).Type), + ExpressionType.Not => IsBoolean(((UnaryExpression)expr).Type), + ExpressionType.Equal or ExpressionType.NotEqual or ExpressionType.LessThan or ExpressionType.LessThanOrEqual or + ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or (ExpressionType)DatabaseExpressionType.IsNull or + (ExpressionType)DatabaseExpressionType.Between or (ExpressionType)DatabaseExpressionType.Exists or (ExpressionType)DatabaseExpressionType.In => true, + ExpressionType.Call => IsBoolean(((MethodCallExpression)expr).Type), + _ => false, + }; + } + + protected virtual Expression VisitPredicate(Expression expr) + { + Visit(expr); + + if (!IsPredicate(expr)) + Write(" <> 0"); + + return expr; + } + + protected virtual Expression? VisitValue(Expression expr) + { + return Visit(expr); + } + + protected override Expression VisitConditional(ConditionalExpression c) + { + throw new NotSupportedException("Conditional expressions not supported."); + } + + protected override Expression VisitConstant(ConstantExpression c) + { + WriteValue(c.Value); + + return c; + } + + protected virtual void WriteValue(object? value) + { + if (value is null) + Write("NULL"); + else if (value.GetType().GetTypeInfo().IsEnum) + Write(Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()))); + else + { + switch (Interop.TypeSystem.GetTypeCode(value.GetType())) + { + case TypeCode.Boolean: + Write((bool)value ? 1 : 0); + break; + case TypeCode.String: + Write(SingleQuote); + Write(value); + Write(SingleQuote); + break; + case TypeCode.Object: + throw new NotSupportedException($"The constant for '{value}' is not supported"); + case TypeCode.Single: + case TypeCode.Double: + var str = ((IConvertible)value).ToString(NumberFormatInfo.InvariantInfo); + + if (!str.Contains(Period)) + str = string.Concat(str, $"{Period}0"); + + Write(str); + break; + default: + Write((value as IConvertible)?.ToString(CultureInfo.InvariantCulture) ?? value); + break; + } + } + } + protected override Expression VisitColumn(ColumnExpression column) + { + if (column.Alias is not null && !HideColumnAliases) + { + WriteAliasName(GetAliasName(column.Alias)); + Write(Period); + } + + WriteColumnName(column.Name); + + return column; + } + protected override Expression VisitProjection(ProjectionExpression proj) + { + // treat these like scalar subqueries + if (proj.Projector is ColumnExpression) + { + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(proj.Select); + Write(CloseBracket); + Indent(Indentation.Outer); + } + else + throw new NotSupportedException("Non-scalar projections cannot be translated to SQL."); + + return proj; + } + + protected override Expression VisitSelect(SelectExpression select) + { + AddAliases(select.From); + Write("SELECT "); + + if (select.IsDistinct) + Write("DISTINCT "); + + if (select.Take is not null) + WriteTopClause(select.Take); + + WriteColumns(select.Columns); + + if (select.From is not null) + { + WriteLine(Indentation.Same); + Write("FROM "); + VisitSource(select.From); + } + + if (select.Where is not null) + WriteWhere(select.Where); + + if (select.GroupBy is not null && select.GroupBy.Any()) + { + WriteLine(Indentation.Same); + Write("GROUP BY "); + + for (var i = 0; i < select.GroupBy.Count; i++) + { + if (i > 0) + Write(", "); + + VisitValue(select.GroupBy[i]); + } + } + + if (select.OrderBy is not null && select.OrderBy.Any()) + { + WriteLine(Indentation.Same); + Write("ORDER BY "); + + for (var i = 0; i < select.OrderBy.Count; i++) + { + var exp = select.OrderBy[i]; + + if (i > 0) + Write(", "); + + VisitValue(exp.Expression); + + if (exp.OrderType != OrderType.Ascending) + Write(" DESC"); + } + } + + return select; + } + + protected virtual void WriteWhere(Expression expression) + { + WriteLine(Indentation.Same); + Write("WHERE "); + VisitPredicate(expression); + } + + protected virtual void WriteTopClause(Expression expression) + { + Write("TOP ("); + Visit(expression); + Write(") "); + } + + protected virtual void WriteColumns(ReadOnlyCollection columns) + { + if (columns.Any()) + { + for (var i = 0; i < columns.Count; i++) + { + var column = columns[i]; + + if (i > 0) + Write(", "); + + var c = VisitValue(column.Expression) as ColumnExpression; + + if (!string.IsNullOrEmpty(column.Name) && (c is null || !string.Equals(c.Name, column.Name, StringComparison.OrdinalIgnoreCase))) + { + Write(Space); + WriteAsColumnName(column.Name); + } + } + } + else + { + Write("NULL "); + + if (IsNested) + { + WriteAsColumnName("tmp"); + Write(Space); + } + } + } + + protected override Expression VisitSource(Expression source) + { + var saveIsNested = IsNested; + + IsNested = true; + + switch ((DatabaseExpressionType)source.NodeType) + { + case DatabaseExpressionType.Table: + var table = (TableExpression)source; + + WriteTableName(table.Schema, table.Name); + + if (!HideTableAliases) + { + Write(Space); + WriteAsAliasName(GetAliasName(table.Alias)); + } + break; + case DatabaseExpressionType.Select: + var select = (SelectExpression)source; + + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(select); + WriteLine(Indentation.Same); + Write($"{CloseBracket} "); + WriteAsAliasName(GetAliasName(select.Alias)); + Indent(Indentation.Outer); + break; + case DatabaseExpressionType.Join: + VisitJoin((JoinExpression)source); + break; + default: + throw new InvalidOperationException("Select source is not valid type"); + } + + IsNested = saveIsNested; + + return source; + } + + protected override Expression VisitJoin(JoinExpression join) + { + VisitJoinLeft(join.Left); + WriteLine(Indentation.Same); + + switch (join.Join) + { + case JoinType.CrossJoin: + Write("CROSS JOIN "); + break; + case JoinType.InnerJoin: + Write("INNER JOIN "); + break; + case JoinType.CrossApply: + Write("CROSS APPLY "); + break; + case JoinType.OuterApply: + Write("OUTER APPLY "); + break; + case JoinType.LeftOuter: + case JoinType.SingletonLeftOuter: + Write("LEFT OUTER JOIN "); + break; + } + + VisitJoinRight(join.Right); + + if (join.Condition is not null) + { + WriteLine(Indentation.Inner); + Write("ON "); + VisitPredicate(join.Condition); + Indent(Indentation.Outer); + } + + return join; + } + + protected virtual Expression VisitJoinLeft(Expression source) + { + return VisitSource(source); + } + + protected virtual Expression VisitJoinRight(Expression source) + { + return VisitSource(source); + } + + protected virtual void WriteAggregateName(string aggregateName) + { + switch (aggregateName) + { + case "Average": + Write("AVG"); + break; + case "LongCount": + Write("COUNT"); + break; + default: + Write(aggregateName.ToUpper()); + break; + } + } + + protected virtual bool RequiresAsteriskWhenNoArgument(string aggregateName) + { + return string.Equals(aggregateName, "Count", StringComparison.Ordinal) || string.Equals(aggregateName, "LongCount", StringComparison.Ordinal); + } + + protected override Expression VisitAggregate(AggregateExpression aggregate) + { + WriteAggregateName(aggregate.AggregateName); + Write(OpenBracket); + + if (aggregate.IsDistinct) + Write("DISTINCT "); + + if (aggregate.Argument is not null) + VisitValue(aggregate.Argument); + else if (RequiresAsteriskWhenNoArgument(aggregate.AggregateName)) + Write("*"); + + Write(CloseBracket); + + return aggregate; + } + + protected override Expression VisitIsNull(IsNullExpression isnull) + { + VisitValue(isnull.Expression); + Write(" IS NULL"); + + return isnull; + } + + protected override Expression VisitBetween(BetweenExpression between) + { + VisitValue(between.Expression); + Write(" BETWEEN "); + VisitValue(between.Lower); + Write(" AND "); + VisitValue(between.Upper); + + return between; + } + + protected override Expression VisitRowNumber(RowNumberExpression rowNumber) + { + throw new NotSupportedException(); + } + + protected override Expression VisitScalar(ScalarExpression subquery) + { + Write(OpenBracket); + WriteLine(Indentation.Inner); + Visit(subquery.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + + return subquery; + } + + protected override Expression VisitExists(ExistsExpression exists) + { + Write($"EXISTS{OpenBracket}"); + WriteLine(Indentation.Inner); + Visit(exists.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + + return exists; + } + protected override Expression VisitIn(InExpression @in) + { + if (@in.Values is not null) + { + if (!@in.Values.Any()) + Write("0 <> 0"); + else + { + VisitValue(@in.Expression); + Write($" IN {OpenBracket}"); + + for (var i = 0; i < @in.Values.Count; i++) + { + if (i > 0) + Write(", "); + + VisitValue(@in.Values[i]); + } + + Write(CloseBracket); + } + } + else + { + VisitValue(@in.Expression); + Write($" IN {OpenBracket}"); + WriteLine(Indentation.Inner); + Visit(@in.Select); + WriteLine(Indentation.Same); + Write(CloseBracket); + Indent(Indentation.Outer); + } + + return @in; + } + + protected override Expression VisitNamedValue(NamedValueExpression value) + { + WriteParameterName(value.Name); + + return value; + } + + protected override Expression VisitIf(IfCommandExpression ifx) + { + throw new NotSupportedException(); + } + + protected override Expression VisitBlock(BlockExpression block) + { + throw new NotSupportedException(); + } + + protected override Expression VisitDeclaration(DeclarationExpression declaration) + { + throw new NotSupportedException(nameof(declaration)); + } + + protected override Expression VisitVariable(VariableExpression vex) + { + WriteVariableName(vex.Name); + + return vex; + } + + protected virtual void VisitStatement(Expression expression) + { + if (expression is ProjectionExpression p) + Visit(p.Select); + else + Visit(expression); + } + + protected override Expression VisitFunction(FunctionExpression func) + { + Write(func.Name); + + if (func.Arguments is not null && func.Arguments.Any()) + { + Write(OpenBracket); + + for (var i = 0; i < func.Arguments.Count; i++) + { + if (i > 0) + Write(", "); + + Visit(func.Arguments[i]); + } + + Write(CloseBracket); + } + + return func; + } } \ No newline at end of file diff --git a/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs b/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs index 8f33eed..18b80b4 100644 --- a/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs +++ b/Connected.Expressions/Translation/Rewriters/WhereClauseRewriter.cs @@ -1,24 +1,24 @@ -using Connected.Expressions.Visitors; -using System.Linq.Expressions; +using System.Linq.Expressions; +using Connected.Expressions.Visitors; namespace Connected.Expressions.Translation.Rewriters; public class WhereClauseRewriter : DatabaseVisitor { - private WhereClauseRewriter(ExpressionCompilationContext context) - { - Context = context; - } + protected WhereClauseRewriter(ExpressionCompilationContext context) + { + Context = context; + } - public ExpressionCompilationContext Context { get; } + public ExpressionCompilationContext Context { get; } - public static Expression Rewrite(ExpressionCompilationContext context, Expression expression) - { - return new WhereClauseRewriter(context).Visit(expression); - } + public static Expression Rewrite(ExpressionCompilationContext context, Expression expression) + { + return new WhereClauseRewriter(context).Visit(expression); + } - protected override Expression VisitWhere(Expression whereExpression) - { - return ParameterRewriter.Rewrite(Context, whereExpression); - } + protected override Expression VisitWhere(Expression whereExpression) + { + return ParameterRewriter.Rewrite(Context, whereExpression); + } } diff --git a/Connected.Instance/EntitySynchronizer.cs b/Connected.Instance/EntitySynchronizer.cs index d41c4a3..f722d70 100644 --- a/Connected.Instance/EntitySynchronizer.cs +++ b/Connected.Instance/EntitySynchronizer.cs @@ -46,7 +46,7 @@ internal static class EntitySynchronizer else { /* - * entitySynchronization = token1, token2,... + * entitySynchronization = token1; token2,... */ var tokens = value.Split(','); /* @@ -66,13 +66,20 @@ internal static class EntitySynchronizer await Synchronize(schemaService, Assembly.Load(AssemblyName.GetAssemblyName(subTokens[1])), logger); } - else if (string.Equals(subTokens[1], "type", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(subTokens[0], "type", StringComparison.OrdinalIgnoreCase)) { - logger.LogTrace("Loading type '{type}'", subTokens[1]); + var typeTokens = subTokens[1].Split("/"); - if (Type.GetType(subTokens[1]) is not Type type) + if (typeTokens.Length != 2) + throw new ArgumentException("Invalid entitySyncronization token '{token}'. Expected type:[assembly]/[type].", subTokens[1]); + + var qualifier = $"{typeTokens[1]}, {typeTokens[0]}"; + + logger.LogTrace("Loading type '{type}'", qualifier); + + if (Type.GetType(qualifier) is not Type type) { - logger.LogWarning("Entity type '{type}' could not be loaded. Synchronization on the specified type could not be performed.", subTokens[1]); + logger.LogWarning("Entity type '{type}' could not be loaded. Synchronization on the specified type could not be performed.", qualifier); continue; } diff --git a/Connected.Interop/Generics.cs b/Connected.Interop/Generics.cs index 0b8d41b..018db59 100644 --- a/Connected.Interop/Generics.cs +++ b/Connected.Interop/Generics.cs @@ -18,5 +18,33 @@ return false; } + + public static bool ImplementsInterface(this Type type, Type interfaceType) + { + var interfaces = type.GetInterfaces(); + + foreach (var i in interfaces) + { + if (!i.IsGenericType) + { + if (i == interfaceType) + return true; + + continue; + } + + var definition = i.GetGenericTypeDefinition(); + + if (definition == interfaceType) + return true; + } + + return false; + } + + public static bool ImplementsInterface(this Type type) + { + return ImplementsInterface(type, typeof(TInterface)); + } } } diff --git a/Connected.Interop/Merging/JsonMerger.cs b/Connected.Interop/Merging/JsonMerger.cs index 51899d0..db925f4 100644 --- a/Connected.Interop/Merging/JsonMerger.cs +++ b/Connected.Interop/Merging/JsonMerger.cs @@ -63,7 +63,7 @@ namespace Connected.Interop.Merging var addMethod = property.PropertyType.GetMethod(nameof(IList.Add), BindingFlags.Public | BindingFlags.Instance | BindingFlags.NonPublic); var instance = addMethod is not null ? Activator.CreateInstance(property.PropertyType) : Array.CreateInstance(property.PropertyType, array.Count); - var elementType = instance.GetType().GetElementType(); + var elementType = property.PropertyType.GenericTypeArguments[0]; for (var i = 0; i < array.Count; i++) { diff --git a/Connected.Interop/Serializer.cs b/Connected.Interop/Serializer.cs index e4732b3..7a166be 100644 --- a/Connected.Interop/Serializer.cs +++ b/Connected.Interop/Serializer.cs @@ -1,9 +1,9 @@ -using Connected.Interop.Merging; -using System.Collections; +using System.Collections; using System.Reflection; using System.Text; using System.Text.Json; using System.Text.Json.Nodes; +using Connected.Interop.Merging; namespace Connected.Interop; @@ -41,12 +41,17 @@ public static class Serializer using var ms = new MemoryStream(); using var writer = new Utf8JsonWriter(ms, new JsonWriterOptions { Indented = true, SkipValidation = false }); - if (value.GetType().IsEnumerable()) - SerializeArray(writer, null, value); - else if (value.GetType().IsTypePrimitive()) - SerializePrimitive(writer, value); + if (value is JsonNode json) + json.WriteTo(writer); else - SerializeObject(writer, null, value); + { + if (value.GetType().IsEnumerable()) + SerializeArray(writer, null, value); + else if (value.GetType().IsTypePrimitive()) + SerializePrimitive(writer, value); + else + SerializeObject(writer, null, value); + } await writer.FlushAsync(); @@ -78,18 +83,20 @@ public static class Serializer if (value is null) return; + var enumerable = property.GetValue(value) as IEnumerable; + + if (enumerable is null) + return; + if (property is not null) writer.WriteStartArray(property.Name.ToCamelCase()); else writer.WriteStartArray(); - if (value is IEnumerable enumerable) - { - var enumerator = enumerable.GetEnumerator(); + var enumerator = enumerable.GetEnumerator(); - while (enumerator.MoveNext()) - SerializeObject(writer, null, enumerator.Current); - } + while (enumerator.MoveNext()) + SerializeObject(writer, null, enumerator.Current); writer.WriteEndArray(); } @@ -100,10 +107,12 @@ public static class Serializer SerializeArray(writer, property, value); else if (!property.PropertyType.IsTypePrimitive()) SerializeObject(writer, property, value); + else + { + writer.WritePropertyName(property.Name.ToCamelCase()); - writer.WritePropertyName(property.Name.ToCamelCase()); - - SerializePrimitive(writer, property.GetValue(value)); + SerializePrimitive(writer, property.GetValue(value)); + } } private static void SerializePrimitive(Utf8JsonWriter writer, object? value) diff --git a/Connected.Interop/TypeSystem.cs b/Connected.Interop/TypeSystem.cs index 9ba736a..edc2de9 100644 --- a/Connected.Interop/TypeSystem.cs +++ b/Connected.Interop/TypeSystem.cs @@ -96,8 +96,10 @@ namespace Connected.Interop return DbType.Byte; else if (underlyingType == typeof(bool)) return DbType.Boolean; - else if (underlyingType == typeof(DateTime) || underlyingType == typeof(DateTimeOffset)) + else if (underlyingType == typeof(DateTime)) return DbType.DateTime2; + else if (underlyingType == typeof(DateTimeOffset)) + return DbType.DateTimeOffset; else if (underlyingType == typeof(decimal)) return DbType.Decimal; else if (underlyingType == typeof(double)) diff --git a/Connected.Net/HttpExtensions.cs b/Connected.Net/HttpExtensions.cs index 1f39a63..c781041 100644 --- a/Connected.Net/HttpExtensions.cs +++ b/Connected.Net/HttpExtensions.cs @@ -68,7 +68,11 @@ public static class HttpExtensions } public static async Task Post(this IHttpService factory, string? requestUri, object? content, CancellationToken cancellationToken = default) { - return await HandleResponse(await factory.CreateClient().SendAsync(await CreatePostMessage(requestUri, content), cancellationToken)); + var client = factory.CreateClient(); + var message = await CreatePostMessage(requestUri, content); + var response = await client.SendAsync(message, cancellationToken); + + return await HandleResponse(response); } private static HttpRequestMessage CreateGetMessage(string? requestUri) {