You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
136 lines
4.5 KiB
136 lines
4.5 KiB
using Connected.Expressions.Expressions;
|
|
using Connected.Expressions.Mappings;
|
|
using Connected.Expressions.Visitors;
|
|
using System.Linq.Expressions;
|
|
using System.Reflection;
|
|
|
|
namespace Connected.Expressions.Translation.Rewriters;
|
|
|
|
public sealed class ComparisonRewriter : DatabaseVisitor
|
|
{
|
|
private ComparisonRewriter()
|
|
{
|
|
}
|
|
|
|
public static Expression Rewrite(Expression expression)
|
|
{
|
|
return new ComparisonRewriter().Visit(expression);
|
|
}
|
|
|
|
protected override Expression VisitBinary(BinaryExpression b)
|
|
{
|
|
switch (b.NodeType)
|
|
{
|
|
case ExpressionType.Equal:
|
|
case ExpressionType.NotEqual:
|
|
var result = Compare(b);
|
|
|
|
if (result == b)
|
|
goto default;
|
|
|
|
return Visit(result);
|
|
default:
|
|
return base.VisitBinary(b);
|
|
}
|
|
}
|
|
|
|
private static Expression SkipConvert(Expression expression)
|
|
{
|
|
while (expression.NodeType == ExpressionType.Convert)
|
|
expression = ((UnaryExpression)expression).Operand;
|
|
|
|
return expression;
|
|
}
|
|
|
|
private Expression Compare(BinaryExpression bop)
|
|
{
|
|
var e1 = SkipConvert(bop.Left);
|
|
var e2 = SkipConvert(bop.Right);
|
|
var oj1 = e1 as OuterJoinedExpression;
|
|
var oj2 = e2 as OuterJoinedExpression;
|
|
var entity1 = oj1 is not null ? oj1.Expression as EntityExpression : e1 as EntityExpression;
|
|
var entity2 = oj2 is not null ? oj2.Expression as EntityExpression : e2 as EntityExpression;
|
|
var negate = bop.NodeType == ExpressionType.NotEqual;
|
|
|
|
if (oj1 is not null && e2.NodeType == ExpressionType.Constant && ((ConstantExpression)e2).Value is null)
|
|
return MakeIsNull(oj1.Test, negate);
|
|
else if (oj2 is not null && e1.NodeType == ExpressionType.Constant && ((ConstantExpression)e1).Value is null)
|
|
return MakeIsNull(oj2.Test, negate);
|
|
|
|
if (entity1 is not null)
|
|
return MakePredicate(e1, e2, MappingsCache.Get(entity1.EntityType).Members.Where(f => f.IsPrimaryKey).Select(f => f.MemberInfo), negate);
|
|
else if (entity2 is not null)
|
|
return MakePredicate(e1, e2, MappingsCache.Get(entity2.EntityType).Members.Where(f => f.IsPrimaryKey).Select(f => f.MemberInfo), negate);
|
|
|
|
var dm1 = GetDefinedMembers(e1);
|
|
var dm2 = GetDefinedMembers(e2);
|
|
|
|
if (dm1 is null && dm2 is null)
|
|
return bop;
|
|
|
|
if (dm1 is not null && dm2 is not null)
|
|
{
|
|
var names1 = new HashSet<string>(dm1.Select(m => m.Name));
|
|
var names2 = new HashSet<string>(dm2.Select(m => m.Name));
|
|
|
|
if (names1.IsSubsetOf(names2) && names2.IsSubsetOf(names1))
|
|
return MakePredicate(e1, e2, dm1, negate);
|
|
}
|
|
else if (dm1 is not null)
|
|
return MakePredicate(e1, e2, dm1, negate);
|
|
else if (dm2 is not null)
|
|
return MakePredicate(e1, e2, dm2, negate);
|
|
|
|
throw new InvalidOperationException("Cannot compare two constructed types with different sets of members assigned.");
|
|
}
|
|
|
|
private static Expression MakeIsNull(Expression expression, bool negate)
|
|
{
|
|
var isnull = new IsNullExpression(expression);
|
|
|
|
return negate ? Expression.Not(isnull) : isnull;
|
|
}
|
|
|
|
private static Expression? MakePredicate(Expression e1, Expression e2, IEnumerable<MemberInfo> members, bool negate)
|
|
{
|
|
var pred = members.Select(m => Binder.Bind(e1, m).Equal(Binder.Bind(e2, m))).Join(ExpressionType.And);
|
|
|
|
if (negate)
|
|
pred = Expression.Not(pred);
|
|
|
|
return pred;
|
|
}
|
|
|
|
private static IEnumerable<MemberInfo> GetDefinedMembers(Expression expr)
|
|
{
|
|
var mini = expr as MemberInitExpression;
|
|
|
|
if (mini is not null)
|
|
{
|
|
var members = mini.Bindings.Select(b => FixMember(b.Member));
|
|
|
|
if (mini.NewExpression.Members is not null)
|
|
members = members.Concat(mini.NewExpression.Members.Select(FixMember));
|
|
|
|
return members;
|
|
}
|
|
else
|
|
{
|
|
var nex = expr as NewExpression;
|
|
|
|
if (nex is not null && nex.Members is not null)
|
|
return nex.Members.Select(FixMember);
|
|
}
|
|
|
|
return null;
|
|
}
|
|
|
|
private static MemberInfo FixMember(MemberInfo member)
|
|
{
|
|
if (member is MethodInfo && member.Name.StartsWith("get_"))
|
|
return member.DeclaringType.GetTypeInfo().GetDeclaredProperty(member.Name[4..]);
|
|
|
|
return member;
|
|
}
|
|
}
|