using Connected.Expressions.Reflection;
using Connected.Interop;
using System.Linq.Expressions;
using System.Reflection;
using ExpressionVisitor = Connected.Expressions.Visitors.ExpressionVisitor;

namespace Connected.Expressions.Evaluation;

/// <summary>
/// Replaces nominated sub-trees of an expression with the values they evaluate to.
/// </summary>
/// <remarks>
/// Every visited node contained in <see cref="Candidates"/> is evaluated eagerly and substituted by a
/// <see cref="ConstantExpression"/>, which is then passed through <see cref="OnEval"/> when one is supplied.
/// Member accesses on a constant are read through reflection and additionally recorded in
/// <see cref="Context"/>'s <c>Parameters</c> under the member's name; every other nominated node is
/// compiled into a parameterless lambda and invoked.
/// </remarks>
internal sealed class SubtreeEvaluator : ExpressionVisitor
{
    private SubtreeEvaluator(ExpressionCompilationContext context, HashSet<Expression> candidates, Func<ConstantExpression, Expression>? onEval)
    {
        Candidates = candidates;
        OnEval = onEval;
        Context = context;
    }

    /// <summary>Compilation context that receives the values read from member accesses on constants.</summary>
    public ExpressionCompilationContext Context { get; }

    /// <summary>Nodes nominated for evaluation; any visited node found in this set is replaced by its value.</summary>
    private HashSet<Expression> Candidates { get; }

    /// <summary>Optional hook applied to each constant produced by <see cref="Evaluate"/> (except nodes that already were constants).</summary>
    private Func<ConstantExpression, Expression>? OnEval { get; }

    /// <summary>
    /// Evaluates every nominated sub-tree of <paramref name="exp"/> and returns the rewritten expression.
    /// </summary>
    /// <param name="context">Context that collects values read from member accesses on constants.</param>
    /// <param name="candidates">Nodes that should be replaced by their evaluated value.</param>
    /// <param name="onEval">Optional post-processing hook for each produced constant.</param>
    /// <param name="exp">Expression to rewrite.</param>
    /// <returns>The rewritten expression.</returns>
    /// <exception cref="NullReferenceException">The visitor produced no expression.</exception>
    internal static Expression Eval(ExpressionCompilationContext context, HashSet<Expression> candidates, Func<ConstantExpression, Expression>? onEval, Expression exp)
    {
        if (new SubtreeEvaluator(context, candidates, onEval).Visit(exp) is not Expression subtreeExpression)
        {
            throw new NullReferenceException(nameof(subtreeExpression));
        }

        return subtreeExpression;
    }

    /// <summary>
    /// Replaces a nominated node wholesale by its value; any other node is walked normally so that
    /// nominated descendants are still found.
    /// </summary>
    protected override Expression? Visit(Expression? exp)
    {
        if (exp is null)
        {
            return null;
        }

        if (Candidates.Contains(exp))
        {
            return Evaluate(exp);
        }

        return base.Visit(exp);
    }

    /// <summary>
    /// When the test of a conditional is nominated, evaluates it once; a <see cref="bool"/> constant result
    /// collapses the conditional to the selected branch.
    /// </summary>
    protected override Expression VisitConditional(ConditionalExpression c)
    {
        if (!Candidates.Contains(c.Test))
        {
            return base.VisitConditional(c);
        }

        var test = Evaluate(c.Test);

        if (test is ConstantExpression { Value: bool condition } constant && constant.Type == typeof(bool))
        {
            return Visit(condition ? c.IfTrue : c.IfFalse)!;
        }

        // The test did not fold to a bool constant (e.g. OnEval rewrote it). Reuse the already evaluated
        // test instead of deferring to base.VisitConditional: that would visit c.Test again and, because it
        // is a candidate, evaluate it a second time (recompiling it and repeating Context.Parameters.Add).
        return c.Update(test, Visit(c.IfTrue)!, Visit(c.IfFalse)!);
    }

    /// <summary>Applies <see cref="OnEval"/> to <paramref name="e"/> when a hook is configured.</summary>
    private Expression PostEval(ConstantExpression e)
    {
        return OnEval is null ? e : OnEval(e);
    }

    /// <summary>
    /// Evaluates <paramref name="e"/> and returns a node of the same type holding its value.
    /// </summary>
    private Expression Evaluate(Expression e)
    {
        // The result must keep the requested type even if a Convert is stripped below.
        var type = e.Type;

        // A Convert that only adds or removes nullability is dropped; the constant is re-typed below instead.
        if (e is UnaryExpression { NodeType: ExpressionType.Convert } convert
            && Nullables.GetNonNullableType(convert.Operand.Type) == Nullables.GetNonNullableType(type))
        {
            e = convert.Operand;
        }

        if (e.NodeType == ExpressionType.Constant)
        {
            // Nodes that already were constants are returned without going through OnEval.
            if (e.Type == type)
            {
                return e;
            }

            if (Nullables.GetNonNullableType(e.Type) == Nullables.GetNonNullableType(type))
            {
                // Simulate the nullable conversion that was stripped above.
                return Expression.Constant(((ConstantExpression)e).Value, type);
            }
        }

        // Member access on a constant instance (e.g. a captured closure variable) is read through reflection
        // rather than by compiling a lambda.
        if (e is MemberExpression { Expression: ConstantExpression target } member)
        {
            var value = member.Member.GetValue(target.Value);
            var constant = Expression.Constant(value, type);

            // NOTE(review): entries are keyed by the bare member name. If Parameters is dictionary-like, two
            // captured members with the same name (or the same member reached twice) would collide on Add —
            // confirm against ExpressionCompilationContext.
            Context.Parameters.Add(member.Member.Name, constant);

            return PostEval(constant);
        }

        // Everything else is compiled into a parameterless lambda and invoked; value types are boxed so a
        // single Func<object> shape serves every result type.
        if (type.GetTypeInfo().IsValueType)
        {
            e = Expression.Convert(e, typeof(object));
        }

        var evaluator = Expression.Lambda<Func<object>>(e).Compile();

        return PostEval(Expression.Constant(evaluator(), type));
    }
}