Alternatives to using a formula for calculated properties with NHibernate using Linq

Tags: linq, nhibernate, calculated, properties, ddd

Tags: linq

Our problem is basic enough. We want to query using reusable business logic rather than having to re-implement it everywhere.

Let's say we have a Person record like this:

/// <summary>
/// Employee record. Members are virtual, presumably so NHibernate can create lazy-loading
/// proxies (confirm against the session configuration).
/// </summary>
public class Person
{
    /// <summary>Identifier of the record.</summary>
    public virtual Guid Id { get; set; }
    /// <summary>The person's name.</summary>
    public virtual string Name { get; set; }
    /// <summary>Date the employment starts; the person counts as active from this date on.</summary>
    public virtual DateTime StartDate { get; set; }
    /// <summary>Date the employment ends; null while the employment is open-ended.</summary>
    public virtual DateTime? EndDate { get; set; }
} 

Often we'd like to check whether this person is currently employed, which we can do like this:

// Inline filter for employees active today: started on or before today and not yet ended.
var activeEmployees = session.Query<Person>()
    .Where(p => p.StartDate <= DateTime.Today && (p.EndDate == null || p.EndDate >= DateTime.Today));

But what we really would like to type is something like this:

IQueryable<Person> activeEmployees = session.Query<Person>()
                                            .Where(p => p.IsActive);

And leave the underlying logic somewhere else. This would also let us expose IsActive to, for example, a LINQ-aware grid and allow the user to filter for only active employees instead of maintaining the filter above. In 90% of cases this is a filter the user would keep applied, removing it only when special needs arise.

One way to do the above is to create a property in the class with that logic, but then we'd need a formula in the mapping that repeats the logic, something like this:

/// <summary>
/// Mapping-by-code for <see cref="Person"/>, including <c>IsActive</c> as a read-only,
/// SQL-calculated column so it can be used in LINQ filters.
/// </summary>
public class PersonMap : ClassMapping<Person>
{
    public PersonMap()
    {
        Id(p => p.Id);
        Property(p => p.Name);
        Property(p => p.StartDate);
        Property(p => p.EndDate);

        // Active = already started AND (no end date OR end date not yet passed).
        // The previous formula compared EndDate with ISNULL(EndDate, ...) - i.e. with itself -
        // so it never looked at the current time, and a NULL EndDate made the whole predicate
        // NULL (never active). The CASE wrapper turns the predicate into a 1/0 value, since a
        // bare boolean predicate is not a valid select-list expression on SQL Server.
        // NOTE(review): CURRENT_TIMESTAMP includes the time of day, while the in-memory
        // IsActive compares against DateTime.Today; an EndDate of today (midnight) is active in
        // memory all day but not in SQL after midnight. Truncate to a date if they must agree.
        // NOTE(review): IsActive has no setter; confirm the configured property accessor can
        // hydrate a getter-only property (e.g. a read-only access strategy).
        Property(p => p.IsActive,
                 m => m.Formula(
                     "CASE WHEN StartDate <= CURRENT_TIMESTAMP AND (EndDate IS NULL OR EndDate >= CURRENT_TIMESTAMP) THEN 1 ELSE 0 END"));
    }
}

and the class itself:

/// <summary>
/// Employee record with an in-memory <see cref="IsActive"/> check.
/// </summary>
public class Person
{
    public virtual Guid Id { get; set; }
    public virtual string Name { get; set; }
    public virtual DateTime StartDate { get; set; }
    public virtual DateTime? EndDate { get; set; }

    /// <summary>
    /// True when the person has started (StartDate on or before today) and either has no
    /// end date or has an end date that is today or later.
    /// </summary>
    public virtual bool IsActive
    {
        get
        {
            if (StartDate > DateTime.Today)
                return false;   // not started yet

            if (!EndDate.HasValue)
                return true;    // open-ended employment

            return EndDate.Value >= DateTime.Today;
        }
    }
}

The problem now is that it is not very testable. We're using both DateTime.Today and CURRENT_TIMESTAMP. DateTime.Today can easily be replaced by a call to a more testable function that lets us define "today" in unit tests, but CURRENT_TIMESTAMP is not so easy to control if we want to run in-memory tests with SQLite, for example. An alternative solution is described here:

https://hendryluk.wordpress.com/2011/09/06/nhibernate-linq-ing-calculated-properties/

The problem is that this approach is applied too late: NHibernate has already parsed the expression by the time it runs. For example, if our expression is

// An untyped lambda has no natural type, so `var` cannot be used: declare the expression type explicitly.
Expression<Func<Person, bool>> IsActiveExpression =
    person => person.StartDate <= DateTime.Today && (person.EndDate == null || person.EndDate >= DateTime.Today);

but query is this:

IQueryable<Person> activeEmployees = session.Query<Person>()
                                            .Where(p => p.IsActive);

then it will fail because p != person. There are also several other possible failures here. The alternative is to inject IsActiveExpression into the query much earlier. To do this, I've created a new IQueryProvider for NHibernate:

/// <summary>
/// NHibernate LINQ query provider that expands convention-based calculated properties
/// (via <see cref="ReplacePropertyWithExpressionByConvention"/>) before handing the
/// expression tree to NHibernate's <see cref="DefaultQueryProvider"/>.
/// </summary>
public class ExpressionUnpackQueryProvider : DefaultQueryProvider
{
    public ExpressionUnpackQueryProvider(ISessionImplementor session)
        : base(session)
    {
    }

    /// <summary>Expands calculated properties, then executes the query immediately.</summary>
    public override object Execute(Expression expression)
    {
        return base.Execute(Unpack(expression));
    }

    /// <summary>Expands calculated properties, then executes the query as a future.</summary>
    public override object ExecuteFuture(Expression expression)
    {
        return base.ExecuteFuture(Unpack(expression));
    }

    // Replaces each property that has a matching static "<Property>Expression" lambda with the
    // lambda's body, so NHibernate only sees members it can translate.
    private static Expression Unpack(Expression expression)
    {
        return new ReplacePropertyWithExpressionByConvention().Visit(expression);
    }
}

Unfortunately, it's not yet possible to tell NHibernate to use this implementation, but a future version will have a configuration property that can be set to the class name of a custom implementation.

In the meantime I'll have to type

// Used in place of session.Query<Person>(); the extension method is defined below.
session.QueryExtended<Person>()

instead of session.Query<Person>(), using this extension method:

/// <summary>
/// Entry points for LINQ queries whose calculated properties are expanded by
/// <see cref="ExpressionUnpackQueryProvider"/>.
/// </summary>
public static class SessionLinqExtensions
{
    /// <summary>
    /// Starts a LINQ query over <typeparamref name="T"/>; use it instead of NHibernate's
    /// <c>Query</c> extension when the query filters on convention-based calculated properties.
    /// </summary>
    /// <param name="session">Open NHibernate session.</param>
    /// <returns>A queryable whose provider is an <see cref="ExpressionUnpackQueryProvider"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="session"/> is null.</exception>
    public static IQueryable<T> QueryExtended<T>(this ISession session)
    {
        if (session == null)
            throw new ArgumentNullException("session");

        var sessionImpl = session.GetSessionImplementation();

        // The root constant is a plain NhQueryable; only the outer queryable carries the
        // unpacking provider. Queryable operators call Provider.CreateQuery on the outer one,
        // so the composed query should keep flowing through ExpressionUnpackQueryProvider
        // (assumes DefaultQueryProvider.CreateQuery passes `this` on - verify per NH version).
        return new NhQueryable<T>(
            new ExpressionUnpackQueryProvider(sessionImpl),
            Expression.Constant(new NhQueryable<T>(sessionImpl)));
    }
}

/// <summary>
/// Minimal re-linq queryable over an arbitrary <see cref="IQueryProvider"/>.
/// NOTE(review): not referenced by the other code shown here; QueryExtended uses NhQueryable.
/// </summary>
public class NhExtendedQueryable<T> : QueryableBase<T>
{
    // Called by user code with an existing provider.
    public NhExtendedQueryable(IQueryProvider provider)
        : base(provider)
    {
    }

    // Called indirectly by LINQ's query methods (through the provider); just pass to base.
    public NhExtendedQueryable(IQueryProvider provider, Expression expression)
        : base(provider, expression)
    {
    }
}

The code that does all the heavy lifting:

/// <summary>
/// Expression visitor that inlines "calculated" properties by naming convention: an access to
/// property <c>Xxx</c> declared on type <c>T</c> is replaced by the body of the public static
/// field <c>XxxExpression</c>, which must hold a lambda from <c>T</c> to the property type.
/// The lambda's parameter is replaced by the expression the property was accessed on.
/// </summary>
public class ReplacePropertyWithExpressionByConvention : ExpressionVisitor
{
    protected override Expression VisitMemberAccess(MemberExpression m)
    {
        // Static members have no instance to substitute for the lambda parameter.
        var expression = m.Expression != null ? GetExpressionField(m.Member) : null;

        if (expression == null)
            return base.VisitMemberAccess(m);

        // Substitute the instance the property was read from (p, p.Manager, ...) for the
        // lambda's single parameter. The instance may be any expression, not only a bare
        // parameter, so it is passed through as-is rather than cast to ParameterExpression.
        var inlined = new VariableRenameVisitor(m.Expression, expression.Parameters[0]).Visit(expression.Body);

        // The inlined body (and the instance embedded in it) may itself use calculated
        // properties; expand those as well. A lambda that references its own property would
        // recurse forever, so such definitions must be avoided.
        return Visit(inlined);
    }

    /// <summary>
    /// Looks up the convention field for <paramref name="m"/>.
    /// </summary>
    /// <returns>The lambda stored in <c>{PropertyName}Expression</c>, or null when
    /// <paramref name="m"/> is not a property or no such public static field exists.</returns>
    /// <exception cref="InvalidOperationException">The field exists but does not hold a lambda
    /// of the expected delegate type.</exception>
    protected virtual LambdaExpression GetExpressionField(MemberInfo m)
    {
        if (m.MemberType != MemberTypes.Property || m.DeclaringType == null)
            return null;

        var name = m.Name + "Expression";
        var staticMember = m.DeclaringType.GetField(name, BindingFlags.Static | BindingFlags.Public);
        if (staticMember == null)
            return null;

        var expression = staticMember.GetValue(null) as LambdaExpression;
        if (expression == null)
        {
            throw new InvalidOperationException(
                string.Format(
                    "The expression named {0} must be a lambda to be usable as a Linq expression for property {1} on type {2} ",
                    name, m.Name, m.DeclaringType.Name));
        }

        var p = (PropertyInfo)m;

        // Contravariance lets a lambda declared on a base type satisfy a derived declaring type.
        var expectedType = typeof(Func<,>).MakeGenericType(m.DeclaringType, p.PropertyType);

        if (!expectedType.IsAssignableFrom(expression.Type))
        {
            throw new InvalidOperationException(
                string.Format(
                    "The expression named {0} must be of type Func<{2},{3}>  to be usable as a Linq expression for property {1} on type {2} ",
                    name, m.Name, m.DeclaringType.Name, p.PropertyType.Name));
        }

        return expression;
    }
}

and

/// <summary>
/// Rewrites an expression by replacing every occurrence of one lambda parameter with another
/// expression (used to inline a lambda body at its call site).
/// </summary>
public class VariableRenameVisitor : ExpressionVisitor
{
    private readonly Expression _rewriteTo;
    private readonly ParameterExpression _rewriteFrom;

    /// <param name="rewriteTo">Expression substituted wherever the parameter appears.</param>
    /// <param name="rewriteFrom">Parameter to replace; when null nothing is rewritten.</param>
    public VariableRenameVisitor(Expression rewriteTo, ParameterExpression rewriteFrom)
    {
        _rewriteTo = rewriteTo;
        _rewriteFrom = rewriteFrom;
    }

    // Replace the parameter node itself instead of only "parameter.Member". Member access
    // (person.StartDate) is still rewritten, because the base VisitMemberAccess visits its
    // target and rebuilds the access. But method arguments, comparisons and casts involving
    // the parameter are now rewritten too, instead of leaking an unbound parameter into the query.
    protected override Expression VisitParameter(ParameterExpression p)
    {
        if (_rewriteFrom != null && p == _rewriteFrom)
            return _rewriteTo;

        return base.VisitParameter(p);
    }
}

Both derive from the ExpressionVisitor suggested by Microsoft (code below). Now I can write this:

/// <summary>
/// Employee record whose <see cref="IsActive"/> rule is defined once, as an expression tree,
/// and reused both in memory and (through ReplacePropertyWithExpressionByConvention) in queries.
/// </summary>
public class Person
{
    public virtual Guid Id { get; set; }
    public virtual string Name { get; set; }
    public virtual DateTime StartDate { get; set; }
    public virtual DateTime? EndDate { get; set; }

    // Must remain a public static field named "<Property>Expression": it is located by that
    // naming convention when LINQ queries are rewritten. Readonly so it cannot drift out of
    // sync with CompiledIsActive below.
    public static readonly Expression<Func<Person, bool>> IsActiveExpression =
        person => person.StartDate <= DateTimeTestable.Today() && (person.EndDate == null || person.EndDate >= DateTimeTestable.Today());

    // Declared after IsActiveExpression on purpose: static field initializers run in textual order.
    public static readonly Func<Person, bool> CompiledIsActive = IsActiveExpression.Compile();

    /// <summary>In-memory evaluation of <see cref="IsActiveExpression"/>.</summary>
    public virtual bool IsActive { get { return CompiledIsActive(this); } }
}

and

IQueryable<Person> activeEmployees = session.QueryExtended<Person>()
                                            .Where(p => p.IsActive);

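and IsActiveExpression will be injected in place of p.IsActive by convention. All I have to do is create a property together with a public static Expression<Func<TEntity, TProperty>> field named after the property with the suffix "Expression" (for example, IsActive and IsActiveExpression).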

The ExpressionVisitor:

/// <summary>
/// Base class for rewriting LINQ expression trees, following the ExpressionVisitor sample
/// Microsoft published for .NET 3.5 LINQ providers. Every Visit* method returns its input
/// unchanged (same reference) when none of its children changed; otherwise it rebuilds the
/// node around the rewritten children. Subclasses override individual Visit* methods.
/// </summary>
public class ExpressionVisitor
{
    /// <summary>Intended to be called only by derived visitors.</summary>
    protected ExpressionVisitor()
    {
    }

    /// <summary>
    /// Dispatches on <see cref="Expression.NodeType"/> to the matching Visit* method.
    /// </summary>
    /// <param name="exp">Node to visit; may be null.</param>
    /// <returns>The (possibly rewritten) node, or null when <paramref name="exp"/> is null.</returns>
    /// <exception cref="NotSupportedException">The node type is not handled by this visitor.</exception>
    public virtual Expression Visit(Expression exp)
    {
        if (exp == null)
            return null;

        switch (exp.NodeType)
        {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.Not:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.ArrayLength:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                return this.VisitUnary((UnaryExpression)exp);
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
            case ExpressionType.RightShift:
            case ExpressionType.LeftShift:
            case ExpressionType.ExclusiveOr:
                return this.VisitBinary((BinaryExpression)exp);
            case ExpressionType.TypeIs:
                return this.VisitTypeIs((TypeBinaryExpression)exp);
            case ExpressionType.Conditional:
                return this.VisitConditional((ConditionalExpression)exp);
            case ExpressionType.Constant:
                return this.VisitConstant((ConstantExpression)exp);
            case ExpressionType.Parameter:
                return this.VisitParameter((ParameterExpression)exp);
            case ExpressionType.MemberAccess:
                return this.VisitMemberAccess((MemberExpression)exp);
            case ExpressionType.Call:
                return this.VisitMethodCall((MethodCallExpression)exp);
            case ExpressionType.Lambda:
                return this.VisitLambda((LambdaExpression)exp);
            case ExpressionType.New:
                return this.VisitNew((NewExpression)exp);
            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return this.VisitNewArray((NewArrayExpression)exp);
            case ExpressionType.Invoke:
                return this.VisitInvocation((InvocationExpression)exp);
            case ExpressionType.MemberInit:
                return this.VisitMemberInit((MemberInitExpression)exp);
            case ExpressionType.ListInit:
                return this.VisitListInit((ListInitExpression)exp);
            default:
                // Any node type not listed above (e.g. Power, UnaryPlus, or the .NET 4 statement
                // nodes such as Block/Assign): fail loudly rather than silently skip a subtree.
                throw new NotSupportedException(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
        }
    }

    /// <summary>Dispatches a member binding (object initializer entry) by its binding type.</summary>
    /// <exception cref="NotSupportedException">The binding type is not handled.</exception>
    protected virtual MemberBinding VisitBinding(MemberBinding binding)
    {
        switch (binding.BindingType)
        {
            case MemberBindingType.Assignment:
                return this.VisitMemberAssignment((MemberAssignment)binding);
            case MemberBindingType.MemberBinding:
                return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
            case MemberBindingType.ListBinding:
                return this.VisitMemberListBinding((MemberListBinding)binding);
            default:
                throw new NotSupportedException(string.Format("Unhandled binding type '{0}'", binding.BindingType));
        }
    }

    protected virtual ElementInit VisitElementInitializer(ElementInit initializer)
    {
        ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
        if (arguments != initializer.Arguments)
        {
            return Expression.ElementInit(initializer.AddMethod, arguments);
        }
        return initializer;
    }

    protected virtual Expression VisitUnary(UnaryExpression u)
    {
        Expression operand = this.Visit(u.Operand);
        if (operand != u.Operand)
        {
            return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
        }
        return u;
    }

    protected virtual Expression VisitBinary(BinaryExpression b)
    {
        Expression left = this.Visit(b.Left);
        Expression right = this.Visit(b.Right);
        Expression conversion = this.Visit(b.Conversion);
        if (left != b.Left || right != b.Right || conversion != b.Conversion)
        {
            // Coalesce may carry a conversion lambda, which the MakeBinary overload below cannot express.
            if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                return Expression.Coalesce(left, right, conversion as LambdaExpression);
            else
                return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
        }
        return b;
    }

    protected virtual Expression VisitTypeIs(TypeBinaryExpression b)
    {
        Expression expr = this.Visit(b.Expression);
        if (expr != b.Expression)
        {
            return Expression.TypeIs(expr, b.TypeOperand);
        }
        return b;
    }

    protected virtual Expression VisitConstant(ConstantExpression c)
    {
        return c;
    }

    protected virtual Expression VisitConditional(ConditionalExpression c)
    {
        Expression test = this.Visit(c.Test);
        Expression ifTrue = this.Visit(c.IfTrue);
        Expression ifFalse = this.Visit(c.IfFalse);
        if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
        {
            return Expression.Condition(test, ifTrue, ifFalse);
        }
        return c;
    }

    protected virtual Expression VisitParameter(ParameterExpression p)
    {
        return p;
    }

    protected virtual Expression VisitMemberAccess(MemberExpression m)
    {
        Expression exp = this.Visit(m.Expression);
        if (exp != m.Expression)
        {
            return Expression.MakeMemberAccess(exp, m.Member);
        }
        return m;
    }

    protected virtual Expression VisitMethodCall(MethodCallExpression m)
    {
        Expression obj = this.Visit(m.Object);
        IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
        if (obj != m.Object || args != m.Arguments)
        {
            return Expression.Call(obj, m.Method, args);
        }
        return m;
    }

    /// <summary>
    /// Visits each expression; returns <paramref name="original"/> itself when nothing changed,
    /// so callers can detect changes by reference comparison.
    /// </summary>
    protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original)
    {
        List<Expression> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            Expression p = this.Visit(original[i]);
            if (list != null)
            {
                list.Add(p);
            }
            else if (p != original[i])
            {
                // First change: copy the unchanged prefix, then keep appending.
                list = new List<Expression>(n);
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(p);
            }
        }
        if (list != null)
        {
            return list.AsReadOnly();
        }
        return original;
    }

    protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
    {
        Expression e = this.Visit(assignment.Expression);
        if (e != assignment.Expression)
        {
            return Expression.Bind(assignment.Member, e);
        }
        return assignment;
    }

    protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding)
    {
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
        if (bindings != binding.Bindings)
        {
            return Expression.MemberBind(binding.Member, bindings);
        }
        return binding;
    }

    protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding)
    {
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
        if (initializers != binding.Initializers)
        {
            return Expression.ListBind(binding.Member, initializers);
        }
        return binding;
    }

    /// <summary>Same copy-on-first-change strategy as <see cref="VisitExpressionList"/>.</summary>
    protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original)
    {
        List<MemberBinding> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            MemberBinding b = this.VisitBinding(original[i]);
            if (list != null)
            {
                list.Add(b);
            }
            else if (b != original[i])
            {
                list = new List<MemberBinding>(n);
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(b);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    /// <summary>Same copy-on-first-change strategy as <see cref="VisitExpressionList"/>.</summary>
    protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original)
    {
        List<ElementInit> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            ElementInit init = this.VisitElementInitializer(original[i]);
            if (list != null)
            {
                list.Add(init);
            }
            else if (init != original[i])
            {
                list = new List<ElementInit>(n);
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(init);
            }
        }
        if (list != null)
            return list;
        return original;
    }

    protected virtual Expression VisitLambda(LambdaExpression lambda)
    {
        Expression body = this.Visit(lambda.Body);
        if (body != lambda.Body)
        {
            return Expression.Lambda(lambda.Type, body, lambda.Parameters);
        }
        return lambda;
    }

    protected virtual NewExpression VisitNew(NewExpression nex)
    {
        IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
        if (args != nex.Arguments)
        {
            // Members is set for anonymous-type constructions and must be preserved.
            if (nex.Members != null)
                return Expression.New(nex.Constructor, args, nex.Members);
            else
                return Expression.New(nex.Constructor, args);
        }
        return nex;
    }

    protected virtual Expression VisitMemberInit(MemberInitExpression init)
    {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
        if (n != init.NewExpression || bindings != init.Bindings)
        {
            return Expression.MemberInit(n, bindings);
        }
        return init;
    }

    protected virtual Expression VisitListInit(ListInitExpression init)
    {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
        if (n != init.NewExpression || initializers != init.Initializers)
        {
            return Expression.ListInit(n, initializers);
        }
        return init;
    }

    protected virtual Expression VisitNewArray(NewArrayExpression na)
    {
        IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
        if (exprs != na.Expressions)
        {
            if (na.NodeType == ExpressionType.NewArrayInit)
            {
                return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
            }
            else
            {
                return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
            }
        }
        return na;
    }

    protected virtual Expression VisitInvocation(InvocationExpression iv)
    {
        IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
        Expression expr = this.Visit(iv.Expression);
        if (args != iv.Arguments || expr != iv.Expression)
        {
            return Expression.Invoke(expr, args);
        }
        return iv;
    }
}

1 Comment

Add a Comment