You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
160 lines
5.0 KiB
160 lines
5.0 KiB
using System.Linq.Expressions;
|
|
using Connected.Expressions.Expressions;
|
|
using Connected.Expressions.Languages;
|
|
using Connected.Expressions.Visitors;
|
|
|
|
namespace Connected.Expressions.Translation;
|
|
|
|
/// <summary>
/// Rewrites a query expression tree so that parameterizable sub-expressions are replaced by
/// <see cref="NamedValueExpression"/> nodes (query parameters).
/// </summary>
/// <remarks>
/// The following are parameterized:
/// <list type="bullet">
/// <item>non-null, non-numeric constants;</item>
/// <item>lambda parameters;</item>
/// <item>conversions of array-index expressions whose operands are constants or parameters;</item>
/// <item>member accesses on an already-named value.</item>
/// </list>
/// Constants are cached in <see cref="Map"/> keyed by <see cref="TypeValue"/>. Everything else is
/// cached in <see cref="ParameterMap"/> keyed by <see cref="HashedExpression"/>. An equal key
/// therefore reuses the same named value. NOTE(review): this relies on the equality semantics of
/// those two key types, which are not visible here.
/// </remarks>
internal sealed class Parameterizer : DatabaseVisitor
{
	private Parameterizer(QueryLanguage language)
	{
		Language = language;
		Map = new();
		ParameterMap = new();
	}

	/// <summary>Language whose type system resolves the column type of each generated parameter.</summary>
	private QueryLanguage Language { get; }

	/// <summary>Named values already created for constants, keyed by the constant's type and value.</summary>
	private Dictionary<TypeValue, NamedValueExpression> Map { get; }

	/// <summary>Named values already created for non-constant expressions, keyed by <see cref="HashedExpression"/>.</summary>
	private Dictionary<HashedExpression, NamedValueExpression> ParameterMap { get; }

	/// <summary>Sequence number used to build unique parameter names (<c>p0</c>, <c>p1</c>, ...).</summary>
	private int Counter { get; set; }

	/// <summary>
	/// Replaces the parameterizable sub-expressions of <paramref name="expression"/> with named values.
	/// </summary>
	/// <param name="language">Language used to resolve the column type of each generated parameter.</param>
	/// <param name="expression">The query expression to rewrite.</param>
	/// <returns>The rewritten expression.</returns>
	/// <exception cref="NullReferenceException">The visitor did not produce an expression.</exception>
	public static Expression Parameterize(QueryLanguage language, Expression expression)
	{
		if (new Parameterizer(language).Visit(expression) is not Expression parameterizedExpression)
		{
			throw new NullReferenceException(nameof(parameterizedExpression));
		}

		return parameterizedExpression;
	}

	/// <summary>
	/// Visits only the SELECT part of a projection. The projector and aggregator are passed
	/// through unchanged. NOTE(review): presumably because they are not part of the SQL sent to
	/// the database; confirm.
	/// </summary>
	/// <exception cref="NullReferenceException">Visiting the select did not yield a <see cref="SelectExpression"/>.</exception>
	protected override Expression VisitProjection(ProjectionExpression expression)
	{
		if (Visit(expression.Select) is not SelectExpression selectExpression)
		{
			throw new NullReferenceException(nameof(selectExpression));
		}

		return UpdateProjection(expression, selectExpression, expression.Projector, expression.Aggregator);
	}

	/// <summary>
	/// Turns a conversion of an array-index expression whose array and index are both constants
	/// or parameters into a single named value. All other unary expressions are visited normally.
	/// </summary>
	protected override Expression VisitUnary(UnaryExpression expression)
	{
		if (expression.NodeType == ExpressionType.Convert && expression.Operand.NodeType == ExpressionType.ArrayIndex)
		{
			var arrayIndex = (BinaryExpression)expression.Operand;

			if (IsConstantOrParameter(arrayIndex.Left) && IsConstantOrParameter(arrayIndex.Right))
			{
				return GetNamedValue(expression);
			}
		}

		return base.VisitUnary(expression);
	}

	/// <summary>Returns <see langword="true"/> when <paramref name="expression"/> is a non-null constant or parameter node.</summary>
	private static bool IsConstantOrParameter(Expression expression)
		=> expression is { NodeType: ExpressionType.Constant or ExpressionType.Parameter };

	/// <summary>
	/// Visits both operands. When one side is a named value and the other a column, the named
	/// value is re-created with the column's query type so the parameter matches the column it
	/// is compared against.
	/// </summary>
	/// <exception cref="NullReferenceException">Visiting an operand did not produce an expression.</exception>
	protected override Expression VisitBinary(BinaryExpression expression)
	{
		if (Visit(expression.Left) is not Expression leftBinaryExpression)
		{
			throw new NullReferenceException(nameof(leftBinaryExpression));
		}

		if (Visit(expression.Right) is not Expression rightBinaryExpression)
		{
			throw new NullReferenceException(nameof(rightBinaryExpression));
		}

		if (leftBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.NamedValue
			&& rightBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.Column)
		{
			var nv = (NamedValueExpression)leftBinaryExpression;
			var column = (ColumnExpression)rightBinaryExpression;

			leftBinaryExpression = new NamedValueExpression(nv.Name, column.QueryType, nv.Value);
		}
		else if (rightBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.NamedValue
			&& leftBinaryExpression.NodeType == (ExpressionType)DatabaseExpressionType.Column)
		{
			var nv = (NamedValueExpression)rightBinaryExpression;
			var column = (ColumnExpression)leftBinaryExpression;

			rightBinaryExpression = new NamedValueExpression(nv.Name, column.QueryType, nv.Value);
		}

		return UpdateBinary(expression, leftBinaryExpression, rightBinaryExpression, expression.Conversion, expression.IsLiftedToNull, expression.Method);
	}

	/// <summary>
	/// Visits a column assignment. If the assigned value became a named value, it is re-created
	/// with the target column's query type.
	/// </summary>
	protected override ColumnAssignment VisitColumnAssignment(ColumnAssignment ca)
	{
		ca = base.VisitColumnAssignment(ca);

		var expression = ca.Expression;

		if (expression is NamedValueExpression nv)
		{
			expression = new NamedValueExpression(nv.Name, ca.Column.QueryType, nv.Value);
		}

		return UpdateColumnAssignment(ca, ca.Column, expression);
	}

	/// <summary>
	/// Replaces a non-null, non-numeric constant with a named value.
	/// Constants with an equal (type, value) key share one named value.
	/// Null and numeric/boolean constants (see <see cref="IsNumeric"/>) are returned unchanged.
	/// </summary>
	protected override Expression VisitConstant(ConstantExpression expression)
	{
		if (expression.Value is null || IsNumeric(expression.Value.GetType()))
		{
			return expression;
		}

		var key = new TypeValue(expression.Type, expression.Value);

		if (!Map.TryGetValue(key, out var nv))
		{
			nv = CreateNamedValue(expression);
			Map.Add(key, nv);
		}

		return nv;
	}

	/// <summary>Replaces every lambda parameter reference with a named value.</summary>
	protected override Expression VisitParameter(ParameterExpression expression) => GetNamedValue(expression);

	/// <summary>
	/// Visits a member access. If its target became a named value, the member access is rebuilt
	/// over that named value's underlying expression, and the whole access is parameterized as one
	/// named value.
	/// </summary>
	protected override Expression VisitMemberAccess(MemberExpression expression)
	{
		expression = (MemberExpression)base.VisitMemberAccess(expression);

		if (expression.Expression is NamedValueExpression nv)
		{
			return GetNamedValue(Expression.MakeMemberAccess(nv.Value, expression.Member));
		}

		return expression;
	}

	/// <summary>
	/// Returns the named value cached for <paramref name="expression"/>, creating and caching a new
	/// one if none exists.
	/// </summary>
	private Expression GetNamedValue(Expression expression)
	{
		var key = new HashedExpression(expression);

		if (!ParameterMap.TryGetValue(key, out var nv))
		{
			nv = CreateNamedValue(expression);
			ParameterMap.Add(key, nv);
		}

		return nv;
	}

	/// <summary>
	/// Creates a named value for <paramref name="expression"/> with the next unique name
	/// (<c>p0</c>, <c>p1</c>, ...) and the column type the language resolves for its CLR type.
	/// Shared by constants and other parameterized expressions, so all names come from a single
	/// counter and cannot collide.
	/// </summary>
	private NamedValueExpression CreateNamedValue(Expression expression)
	{
		var name = $"p{Counter++}";

		return new NamedValueExpression(name, Language.TypeSystem.ResolveColumnType(expression.Type), expression);
	}

	/// <summary>
	/// Returns <see langword="true"/> for numeric types and <see cref="bool"/>. Constants of these
	/// types are left inline by <see cref="VisitConstant"/>.
	/// </summary>
	private static bool IsNumeric(Type type)
	{
		return Interop.TypeSystem.GetTypeCode(type) switch
		{
			TypeCode.Boolean or TypeCode.Byte or TypeCode.Decimal or TypeCode.Double or TypeCode.Int16 or TypeCode.Int32 or TypeCode.Int64
				or TypeCode.SByte or TypeCode.Single or TypeCode.UInt16 or TypeCode.UInt32 or TypeCode.UInt64 => true,
			_ => false,
		};
	}
}
|