using Connected.Expressions.Expressions; using Connected.Expressions.Translation.Resolvers; using Connected.Expressions.Visitors; using System.Linq.Expressions; namespace Connected.Expressions.Translation.Rewriters; public sealed class AggregateRewriter : DatabaseVisitor { private AggregateRewriter(ExpressionCompilationContext context, Expression expr) { Context = context; Map = new(); Lookup = AggregateResolver.Resolve(expr).ToLookup(a => a.GroupByAlias); } private ExpressionCompilationContext Context { get; set; } private ILookup Lookup { get; set; } private Dictionary Map { get; set; } public static Expression Rewrite(ExpressionCompilationContext context, Expression expr) { if (new AggregateRewriter(context, expr).Visit(expr) is not Expression aggregateExpression) throw new NullReferenceException(nameof(aggregateExpression)); return aggregateExpression; } protected override Expression VisitSelect(SelectExpression select) { select = (SelectExpression)base.VisitSelect(select); if (Lookup.Contains(select.Alias)) { var aggColumns = new List(select.Columns); foreach (AggregateSubqueryExpression ae in Lookup[select.Alias]) { var name = $"agg{aggColumns.Count}"; var colType = Context.Language.TypeSystem.ResolveColumnType(ae.Type); var cd = new ColumnDeclaration(name, ae.AggregateInGroupSelect, colType); Map.Add(ae, new ColumnExpression(ae.Type, colType, ae.GroupByAlias, name)); aggColumns.Add(cd); } return new SelectExpression(select.Alias, aggColumns, select.From, select.Where, select.OrderBy, select.GroupBy, select.IsDistinct, select.Skip, select.Take, select.IsReverse); } return select; } protected override Expression VisitAggregateSubquery(AggregateSubqueryExpression expression) { if (Map.TryGetValue(expression, out Expression mapped)) return mapped; return Visit(expression.AggregateAsSubquery); } }