diff --git a/GraphDiff/GraphDiff.Tests/Tests/QueryLoaderBehaviours.cs b/GraphDiff/GraphDiff.Tests/Tests/QueryLoaderBehaviours.cs index a752cae..b71a831 100644 --- a/GraphDiff/GraphDiff.Tests/Tests/QueryLoaderBehaviours.cs +++ b/GraphDiff/GraphDiff.Tests/Tests/QueryLoaderBehaviours.cs @@ -2,6 +2,9 @@ using System.Data.Entity.Infrastructure; using Microsoft.VisualStudio.TestTools.UnitTesting; using RefactorThis.GraphDiff.Tests.Models; +using System; +using System.Data.Entity; +using System.Linq; namespace RefactorThis.GraphDiff.Tests.Tests { @@ -95,5 +98,65 @@ public void ShouldPerformMutlipleQueriesWhenRequested() // TODO how do I test number of queries.. } + + [TestMethod] + public void ShouldPrametrizeLoadingQuery() + { + using (var context = new TestDbContext()) + { + context.Set().Add(_oneToOneAssociated); + context.Set().Add(_oneToManyAssociated); + context.Nodes.Add(_node); + context.SaveChanges(); + } + + var x = new LocalDbConnectionFactory("v11.0"); + var connection = x.CreateConnection("GraphDiff"); + + using (var context = new TestDbContext(connection)) + { + using (var logCollector = new QueryLogCollector(context)) + { + // Setup mapping + context.UpdateGraph( + entity: _node, + mapping: map => map.OwnedEntity(p => p.OneToOneOwned, with => with + .OwnedEntity(p => p.OneToOneOneToOneOwned) + .AssociatedEntity(p => p.OneToOneOneToOneAssociated) + .OwnedCollection(p => p.OneToOneOneToManyOwned) + .AssociatedCollection(p => p.OneToOneOneToManyAssociated)), + updateParams: new UpdateParams { QueryMode = QueryMode.SingleQuery }); + + Assert.IsTrue(logCollector.Logs.Any(l => l.Contains("@p__linq__0")), + "Can't find a parameter in the loading query"); + } + } + + // TODO how do I test number of queries.. + } + } + + internal class QueryLogCollector : IDisposable + { + private readonly DbContext _context; + private readonly Action _originalLogger; + private readonly List _logs = new List(); + + public List Logs + { + get { return _logs; } + } + + public QueryLogCollector(DbContext context) + { + _context = context; + _originalLogger = _context.Database.Log; + _context.Database.Log = line => _logs.Add(line); + } + + public void Dispose() + { + _context.Database.Log = _originalLogger; + } } } diff --git a/GraphDiff/GraphDiff/Internal/QueryLoader.cs b/GraphDiff/GraphDiff/Internal/QueryLoader.cs index 7d37395..cc70911 100644 --- a/GraphDiff/GraphDiff/Internal/QueryLoader.cs +++ b/GraphDiff/GraphDiff/Internal/QueryLoader.cs @@ -78,7 +78,31 @@ private Expression> CreateKeyPredicateExpression(T entity) private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter) { - return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType)); + return Expression.Equal(Expression.Property(parameter, keyProperty), + ExpressionParameterHelper.GetParameter(keyProperty.GetValue(entity, null), + keyProperty.PropertyType)); } } + + internal class ExpressionParameterHelper + { + public static MemberExpression GetParameter(object value, Type type) + { + MethodInfo method = typeof(ExpressionParameterHelper).GetMethod("GetParameterInternal", BindingFlags.Static | BindingFlags.NonPublic); + MethodInfo genericMethod = method.MakeGenericMethod(type); + return (MemberExpression)genericMethod.Invoke(null, new object[] { value }); + } + + private static MemberExpression GetParameterInternal(TValue value) + { + var closure = new ExpressionParameterField { Value = value }; + return Expression.Field(Expression.Constant(closure), "Value"); + } + + private class ExpressionParameterField + { + public T Value; + } + } + }