using System.Linq.Expressions;
using Connected.Expressions.Evaluation;
using Connected.Expressions.Expressions;
using Connected.Expressions.Languages;
using Connected.Expressions.Visitors;

namespace Connected.Expressions.Translation.Projections;

/// <summary>
/// Affinity passed to <see cref="ExpressionNominator"/> when choosing which subexpressions of a
/// projection become column candidates. (Presumably Client/Server refer to where a subexpression
/// is preferred to be evaluated — confirm against ExpressionNominator.)
/// </summary>
internal enum ProjectionAffinity
{
	Client,
	Server
}

/// <summary>
/// Rewrites a projection expression so that every nominated subexpression is replaced by a
/// <see cref="ColumnExpression"/> that references a column declared under <c>newAlias</c>,
/// collecting the <see cref="ColumnDeclaration"/>s required to produce those columns.
/// </summary>
internal sealed class ColumnProjector : DatabaseVisitor
{
	private readonly QueryLanguage _language;
	private readonly Dictionary<ColumnExpression, ColumnExpression> _map;
	private readonly List<ColumnDeclaration> _columns;
	private readonly HashSet<string> _columnNames;
	private readonly HashSet<Expression> _candidates;
	private readonly HashSet<Alias> _existingAliases;
	private readonly Alias _newAlias;

	private ColumnProjector(QueryLanguage language, ProjectionAffinity affinity, Expression expression, IEnumerable<ColumnDeclaration>? existingColumns, Alias newAlias, IEnumerable<Alias> existingAliases)
	{
		_language = language;
		_newAlias = newAlias;
		_existingAliases = new HashSet<Alias>(existingAliases);
		_map = new Dictionary<ColumnExpression, ColumnExpression>();
		/*
		 * Materialize the incoming columns once; deriving the name set from the list avoids
		 * enumerating a possibly lazy sequence a second time.
		 */
		_columns = existingColumns is null ? new List<ColumnDeclaration>() : new List<ColumnDeclaration>(existingColumns);
		_columnNames = new HashSet<string>(_columns.Select(c => c.Name));
		_candidates = ExpressionNominator.Nominate(Language, affinity, expression);
	}

	private QueryLanguage Language => _language;
	/// <summary>Source column -> column re-declared under <see cref="NewAlias"/> during this projection.</summary>
	private Dictionary<ColumnExpression, ColumnExpression> Map => _map;
	/// <summary>Declarations produced so far (seeded with any existing columns).</summary>
	private List<ColumnDeclaration> Columns => _columns;
	/// <summary>Every column name already declared; used to keep generated names unique.</summary>
	private HashSet<string> ColumnNames => _columnNames;
	/// <summary>Subexpressions nominated to become projected columns.</summary>
	private HashSet<Expression> Candidates => _candidates;
	/// <summary>Aliases whose columns may be re-declared under <see cref="NewAlias"/>.</summary>
	private HashSet<Alias> ExistingAliases => _existingAliases;
	private Alias NewAlias => _newAlias;
	/// <summary>Counter feeding the generated <c>c{n}</c> column names.</summary>
	private int ColumnCounter { get; set; }

	/// <summary>
	/// Projects the nominated subexpressions of <paramref name="expression"/> into columns under <paramref name="newAlias"/>.
	/// </summary>
	/// <param name="language">Query language whose type system resolves column types.</param>
	/// <param name="affinity">Affinity forwarded to the nominator.</param>
	/// <param name="expression">Projection expression to rewrite.</param>
	/// <param name="existingColumns">Columns already declared under <paramref name="newAlias"/>, or <c>null</c>.</param>
	/// <param name="newAlias">Alias that the produced columns are declared under.</param>
	/// <param name="existingAliases">Aliases whose columns may be re-declared under <paramref name="newAlias"/>.</param>
	/// <returns>The rewritten expression together with all column declarations (existing ones first).</returns>
	public static ProjectedColumns ProjectColumns(QueryLanguage language, ProjectionAffinity affinity, Expression expression, IEnumerable<ColumnDeclaration>? existingColumns, Alias newAlias, IEnumerable<Alias> existingAliases)
	{
		var projector = new ColumnProjector(language, affinity, expression, existingColumns, newAlias, existingAliases);
		var expr = projector.Visit(expression);

		return new ProjectedColumns(expr, projector.Columns.AsReadOnly());
	}

	/// <summary>
	/// Same as the affinity overload, using <see cref="ProjectionAffinity.Client"/>.
	/// </summary>
	public static ProjectedColumns ProjectColumns(QueryLanguage language, Expression expression, IEnumerable<ColumnDeclaration>? existingColumns, Alias newAlias, IEnumerable<Alias> existingAliases)
	{
		return ProjectColumns(language, ProjectionAffinity.Client, expression, existingColumns, newAlias, existingAliases);
	}

	/// <summary>
	/// <c>params</c> convenience overload of the affinity overload.
	/// </summary>
	public static ProjectedColumns ProjectColumns(QueryLanguage language, ProjectionAffinity affinity, Expression expression, IEnumerable<ColumnDeclaration>? existingColumns, Alias newAlias, params Alias[] existingAliases)
	{
		return ProjectColumns(language, affinity, expression, existingColumns, newAlias, (IEnumerable<Alias>)existingAliases);
	}

	/// <summary>
	/// <c>params</c> convenience overload using <see cref="ProjectionAffinity.Client"/>.
	/// </summary>
	public static ProjectedColumns ProjectColumns(QueryLanguage language, Expression expression, IEnumerable<ColumnDeclaration>? existingColumns, Alias newAlias, params Alias[] existingAliases)
	{
		return ProjectColumns(language, expression, existingColumns, newAlias, (IEnumerable<Alias>)existingAliases);
	}

	/// <summary>
	/// Replaces nominated subexpressions with column references; other nodes are visited normally.
	/// </summary>
	protected override Expression? Visit(Expression? expression)
	{
		if (expression is null || !Candidates.Contains(expression))
			return base.Visit(expression);

		if (expression.NodeType == (ExpressionType)DatabaseExpressionType.Column)
			return ProjectColumn((ColumnExpression)expression);

		return ProjectComputed(expression);
	}

	/// <summary>
	/// Maps a nominated column onto <see cref="NewAlias"/>, reusing an existing declaration when possible.
	/// Columns from aliases outside <see cref="ExistingAliases"/> are returned unchanged.
	/// </summary>
	private Expression ProjectColumn(ColumnExpression column)
	{
		if (Map.TryGetValue(column, out var mapped))
			return mapped;

		/*
		 * Reuse a declaration that already projects the same source column.
		 */
		foreach (var existingColumn in Columns)
		{
			if (existingColumn.Expression is ColumnExpression cex && cex.Alias == column.Alias && cex.Name == column.Name)
				return new ColumnExpression(column.Type, column.QueryType, NewAlias, existingColumn.Name);
		}

		if (!ExistingAliases.Contains(column.Alias))
			return column;

		var columnName = GetUniqueColumnName(column.Name);

		Columns.Add(new ColumnDeclaration(columnName, column, column.QueryType));
		mapped = new ColumnExpression(column.Type, column.QueryType, NewAlias, columnName);
		Map.Add(column, mapped);
		ColumnNames.Add(columnName);

		return mapped;
	}

	/// <summary>
	/// Declares a nominated non-column subexpression as a new generated column under <see cref="NewAlias"/>.
	/// </summary>
	private Expression ProjectComputed(Expression expression)
	{
		var columnName = GetNextColumnName();
		var columnType = Language.TypeSystem.ResolveColumnType(expression.Type);

		Columns.Add(new ColumnDeclaration(columnName, expression, columnType));
		/*
		 * Register the generated name; otherwise a later GetUniqueColumnName call for a source
		 * column with the same name (e.g. "c0") would not detect the clash and declare a duplicate.
		 */
		ColumnNames.Add(columnName);

		return new ColumnExpression(expression.Type, columnType, NewAlias, columnName);
	}

	private bool IsColumnNameInUse(string name) => ColumnNames.Contains(name);

	/// <summary>
	/// Returns <paramref name="name"/>, or <paramref name="name"/> followed by the first numeric suffix (1, 2, ...) not yet in use.
	/// Does not reserve the returned name; callers add it to <see cref="ColumnNames"/>.
	/// </summary>
	private string GetUniqueColumnName(string name)
	{
		var baseName = name;
		var suffix = 1;

		while (IsColumnNameInUse(name))
			name = baseName + suffix++;

		return name;
	}

	/// <summary>
	/// Generates the next <c>c{n}</c> column name, made unique against <see cref="ColumnNames"/>.
	/// </summary>
	private string GetNextColumnName() => GetUniqueColumnName($"c{ColumnCounter++}");
}