WCF Data Services: removing values from the payload with the EF provider

One of the more frequent questions about WCF Data Services is how to remove a property from the payload. Actually removing the property from the payload requires implementing IDataServiceMetadataProvider, and there is a good blog post about doing that. I started down that road about 9 months ago before ditching it: there was too much copying and pasting from the WCF Data Services core into my implementation, null projections to handle, and a list of problems that kept growing. I’m not going to show how to do that. Instead, I’m going to leave the properties in the payload but not map any values into them. This is probably better, since the metadata does not change, but some people think differently.

I want to point out that I tried this approach about 9 months ago and ditched it as well, but this great sample code from Derrick VanArnam showed me the light. I can’t say enough about that code; without it I wouldn’t have been able to do this. Some of the code below is a direct copy of his sample and some is close but with my tweaks. I’m not trying to pass his great work off as my own.

First, here is my setup: a DataModel context that has a list of customers and their invoices.

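using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;

public partial class DataModel : DbContext
{
    public DataModel()
        : base("name=DataModel")
    {
    }

    // Database-first model: the mapping comes from the EDMX, not from code
    protected override void OnModelCreating(DbModelBuilder modelBuilder)
    {
        throw new UnintentionalCodeFirstException();
    }

    public IDbSet<Customer> Customers { get; set; }
    public IDbSet<Invoice> Invoices { get; set; }
}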

public partial class Customer
{
    public Customer()
    {
        this.Invoices = new HashSet<Invoice>();
    }

    public string CustomerId { get; set; }
    public string Company { get; set; }
    public string Contact { get; set; }
    public string Title { get; set; }
    public string Address { get; set; }
    public string City { get; set; }
    public string State { get; set; }
    public string Zip { get; set; }
    public string Phone { get; set; }
    // decimal is an assumption for the money columns; use your columns' actual type
    public Nullable<decimal> CreditLimit { get; set; }
    public Nullable<decimal> Balance { get; set; }

    public virtual ICollection<Invoice> Invoices { get; set; }
}

public partial class Invoice
{
    public string InvoiceId { get; set; }
    // DateTime and decimal are assumptions; use your columns' actual types
    public Nullable<DateTime> Invoiced { get; set; }
    public string CustomerId { get; set; }
    public string SalesPerson { get; set; }
    public Nullable<decimal> Amount { get; set; }
    public Nullable<decimal> Discount { get; set; }
    public Nullable<decimal> Paid { get; set; }

    public virtual Customer Customers { get; set; }
}

Here is my DataService

using System;
using System.Data.Services;
using System.Data.Services.Common;

public class WcfDataService : DataService<DataModel>, IServiceProvider
{
    // This method is called only once to initialize service-wide policies.
    public static void InitializeService(DataServiceConfiguration config)
    {
        config.SetEntitySetAccessRule("*", EntitySetRights.AllRead);
        config.SetServiceOperationAccessRule("*", ServiceOperationRights.All);
        config.DataServiceBehavior.MaxProtocolVersion = DataServiceProtocolVersion.V3;
        config.UseVerboseErrors = true;
    }

    private readonly DataModel _dataSource;
    private readonly WCFEFProvider<DataModel> _provider;

    public WcfDataService()
    {
        _dataSource = new DataModel();
        _provider = new WCFEFProvider<DataModel>(this, _dataSource);
    }

    // WCF Data Services asks for its provider interfaces here; answer with our
    // provider for any of them and null for everything else
    public object GetService(Type serviceType)
    {
        if (serviceType.IsInstanceOfType(_provider))
        {
            return _provider;
        }
        return null;
    }

    protected override DataModel CreateDataSource()
    {
        return _dataSource;
    }
}
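WCF Data Services asks the service for its provider interfaces (for example IDataServiceMetadataProvider and IDataServiceQueryProvider) through IServiceProvider.GetService, and the IsInstanceOfType check above answers with the one provider object for every interface it implements. A minimal, BCL-only sketch of that behavior; IMetadataSource, IQuerySource, CombinedProvider and DemoService are hypothetical stand-ins, not WCF types:

```csharp
using System;

// Hypothetical stand-ins for two provider interfaces the framework might request.
public interface IMetadataSource { string Name { get; } }
public interface IQuerySource { int Version { get; } }

// One object implements both, the way a single EF provider covers several roles.
public class CombinedProvider : IMetadataSource, IQuerySource
{
    public string Name { get { return "combined"; } }
    public int Version { get { return 3; } }
}

public class DemoService : IServiceProvider
{
    private readonly CombinedProvider _provider = new CombinedProvider();

    public object GetService(Type serviceType)
    {
        // Same check as the service above: return the provider for any type it
        // can be cast to; null means "no service of that type here".
        if (serviceType.IsInstanceOfType(_provider))
        {
            return _provider;
        }
        return null;
    }
}
```

Note that the check also matches typeof(object) and any base class of the provider, which is harmless here but worth knowing when one object answers for everything.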

WCFEFProvider is my version of the AdventureWorksEFProvider from the Microsoft sample code. I also named my interface IQueryWrapper instead of IObjectQueryWrapper, and I exposed only the IQueryable, since that’s all that was needed.

public interface IQueryWrapper
{
    IQueryable Query { get; }
}

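using System;
using System.Data.Services.Providers;
using System.Linq;

public class WCFEFProvider<T> : EntityFrameworkDataServiceProvider
{
    /// <summary>
    /// EF provider whose query roots are wrapped so execution can be intercepted
    /// </summary>
    /// <param name="service">Provider service</param>
    /// <param name="container">Entity container</param>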
    public WCFEFProvider(object service, T container)
        : base(new DataServiceProviderArgs(service, container, null, false))
    {
    }

    /// <summary>
    /// Override the query root
    /// </summary>
    /// <param name="resourceSet"></param>
    /// <returns></returns>
    public override IQueryable GetQueryRootForResourceSet(ResourceSet resourceSet)
    {
        // Wrap the query root so that execution goes through WrappedQueryProvider
        return WrappedQueryProvider.CreateQuery(base.GetQueryRootForResourceSet(resourceSet));
    }

    /// <summary>
    /// Unwrap our query so that base.GetResource receives the underlying ObjectQuery
    /// </summary>
    /// <param name="query"></param>
    /// <param name="fullTypeName"></param>
    /// <returns></returns>
    public override object GetResource(IQueryable query, string fullTypeName)
    {
        var queryWrapper = query as IQueryWrapper;
        if (queryWrapper != null)
        {
            query = queryWrapper.Query;
        }
        return base.GetResource(query, fullTypeName);
    }

    // EF only lets us project onto subclasses of the entities, so tell WCF which ResourceType each subclass stands in for
    public override ResourceType GetResourceType(object target)
    {
        var type = target.GetType();
        var mainclass =
            SecurityProjection.SubClasses.Where(kv => kv.Value == type).Select(kv => kv.Key).FirstOrDefault();
        if (mainclass != null)
        {
            return base.GetResourceType(Activator.CreateInstance(mainclass));
        }
        return base.GetResourceType(target);
    }
}
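GetResourceType is where a projected object gets mapped back to its entity: WCF sees an instance of the runtime subclass, and the override looks that subclass up in SecurityProjection.SubClasses to find the entity it stands in for. The lookup on its own, as a sketch with hypothetical Customer and CustomerSubClass types (the real override then passes an instance of the entity to base.GetResourceType, which needs the EF provider):

```csharp
using System;
using System.Collections.Concurrent;
using System.Linq;

// Hypothetical entity and the kind of empty subclass the projection produces.
public class Customer { }
public class CustomerSubClass : Customer { }

public static class EntityTypeResolver
{
    // entity type -> projection subclass, the same shape as SecurityProjection.SubClasses
    public static readonly ConcurrentDictionary<Type, Type> SubClasses =
        new ConcurrentDictionary<Type, Type>();

    public static Type Resolve(object target)
    {
        Type type = target.GetType();
        // Reverse lookup: which entity registered this runtime type as its subclass?
        Type mainClass = SubClasses
            .Where(kv => kv.Value == type)
            .Select(kv => kv.Key)
            .FirstOrDefault();
        // Anything that is not a registered subclass resolves to its own type.
        return mainClass ?? type;
    }
}
```

Because each generated subclass derives directly from its entity, target.GetType().BaseType would give the same answer for those objects; the dictionary scan has the advantage of never matching an unrelated subclass. It is linear in the number of registered types, so a second dictionary keyed by subclass would make it a single lookup if that ever matters.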

As a note, the GetResource override is the piece I never figured out 9 months ago, and it’s what made me bail on this approach back then. I also renamed EFParameterizedQueryProvider to WrappedQueryProvider and EFParameterizedQuery to WrappedIQueryable, and made some small tweaks to both.

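using System;
using System.Data.Entity.Core.Objects;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class WrappedQueryProvider : IQueryProvider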
{
    /// <summary>
    /// Cache the CreateWrappedIQueryable generic MethodInfo
    /// </summary>
    readonly static MethodInfo CreateQueryMethod =
        typeof(WrappedQueryProvider).GetMethod("CreateWrappedIQueryable",
            BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// The underlying Entity Framework query provider
    /// </summary>
    private readonly IQueryProvider _underlyingQueryProvider;

    public WrappedQueryProvider(IQueryProvider underlyingQueryProvider)
    {
        _underlyingQueryProvider = underlyingQueryProvider;
    }

    public static IQueryable CreateQuery(IQueryable underlyingQuery)
    {
        // Wrap it so we can intercept it
        var provider = new WrappedQueryProvider(underlyingQuery.Provider);
        Type elementType = underlyingQuery.Expression.Type.GetQueryElementType();
        return (IQueryable)CreateQueryMethod.MakeGenericMethod(elementType)
            .Invoke(provider, new object[] { underlyingQuery.Expression, underlyingQuery });
    }

    public IQueryable CreateQuery(Expression expression)
    {
        Type elementType = expression.Type.GetQueryElementType();
        return (IQueryable)CreateQueryMethod.MakeGenericMethod(elementType).Invoke(this, new object[] { expression, null });
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return CreateWrappedIQueryable<TElement>(expression, null);
    }

    private WrappedIQueryable<TElement> CreateWrappedIQueryable<TElement>(Expression expression, IQueryable queryable)
    {
        var objectQuery = queryable as ObjectQuery<TElement> ??
                          (ObjectQuery<TElement>)_underlyingQueryProvider.CreateQuery<TElement>(expression);
        return new WrappedIQueryable<TElement>(objectQuery, this);
    }

    public object Execute(Expression expression)
    {
        // Here we modify the expression: apply the security projection, then parameterize it
        var securedExpression = new SecurityProjection().CheckSecurity(expression);
        var parameterizedExpression = new ParameterizeExpressionVisitor().Parameterize(securedExpression);
        if (typeof(IQueryable).IsAssignableFrom(expression.Type))
        {
            return _underlyingQueryProvider.CreateQuery(parameterizedExpression);
        }
        return _underlyingQueryProvider.Execute(parameterizedExpression);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return (TResult)Execute(expression);
    }
}

public class WrappedIQueryable<T> : IOrderedQueryable<T>, IQueryWrapper
{
    /// <summary>
    /// The original Entity Framework ObjectQuery
    /// </summary>
    private readonly IQueryable _queryable;

    /// <summary>
    /// The wrapping query provider (WrappedQueryProvider) that intercepts execution
    /// </summary>
    private readonly IQueryProvider _queryProvider;

    public WrappedIQueryable(IQueryable<T> objectQuery, IQueryProvider queryProvider)
    {
        _queryProvider = queryProvider;
        _queryable = objectQuery;
    }

    public IEnumerator<T> GetEnumerator()
    {
        return _queryProvider.Execute<IEnumerable<T>>(Expression).GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof (T); }
    }

    public Expression Expression
    {
        get { return _queryable.Expression; }
    }

    public IQueryProvider Provider
    {
        get { return _queryProvider; }
    }

    IQueryable IQueryWrapper.Query
    {
        get { return _queryable; }
    }
}
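The wrapper pair above follows a general pattern: a query object that reports a custom provider, and a provider that builds queries with the real one but routes every execution through a single Execute method where the expression can be rewritten. A stripped-down sketch of the same pattern over LINQ to Objects (AsQueryable) instead of EF, so it runs without a database. InterceptingProvider, InterceptedQuery and IInnerQuery are hypothetical names; IInnerQuery plays the role of IQueryWrapper, and the rewrite delegate stands in for the SecurityProjection and parameterization visitors:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Plays the role of IQueryWrapper: lets callers reach the unwrapped query.
public interface IInnerQuery
{
    IQueryable Inner { get; }
}

public class InterceptingProvider : IQueryProvider
{
    private readonly IQueryProvider _inner;
    private readonly Func<Expression, Expression> _rewrite;

    public InterceptingProvider(IQueryProvider inner, Func<Expression, Expression> rewrite)
    {
        _inner = inner;
        _rewrite = rewrite;
    }

    public static IQueryable<T> Wrap<T>(IQueryable<T> source, Func<Expression, Expression> rewrite)
    {
        return new InterceptedQuery<T>(source, new InterceptingProvider(source.Provider, rewrite));
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        // Compose with the inner provider but stay wrapped, so enumeration
        // still comes back through Execute below.
        return new InterceptedQuery<TElement>(_inner.CreateQuery<TElement>(expression), this);
    }

    public IQueryable CreateQuery(Expression expression)
    {
        IQueryable innerQuery = _inner.CreateQuery(expression);
        Type wrapperType = typeof(InterceptedQuery<>).MakeGenericType(innerQuery.ElementType);
        return (IQueryable)Activator.CreateInstance(wrapperType, innerQuery, this);
    }

    public object Execute(Expression expression)
    {
        // The single interception point for enumeration and scalar operators.
        Expression rewritten = _rewrite(expression);
        if (typeof(IQueryable).IsAssignableFrom(rewritten.Type))
        {
            return _inner.CreateQuery(rewritten);
        }
        return _inner.Execute(rewritten);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return (TResult)Execute(expression);
    }
}

public class InterceptedQuery<T> : IOrderedQueryable<T>, IInnerQuery
{
    private readonly IQueryable<T> _inner;
    private readonly IQueryProvider _provider;

    public InterceptedQuery(IQueryable<T> inner, IQueryProvider provider)
    {
        _inner = inner;
        _provider = provider;
    }

    public Type ElementType { get { return typeof(T); } }
    public Expression Expression { get { return _inner.Expression; } }
    public IQueryProvider Provider { get { return _provider; } }
    IQueryable IInnerQuery.Inner { get { return _inner; } }

    public IEnumerator<T> GetEnumerator()
    {
        return _provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
```

As in WrappedIQueryable, the Expression property returns the inner query’s expression, so the wrapper never appears inside the expression tree; only its Provider does, and that is what brings Queryable operators back to CreateQuery<TElement>.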

You will also need the TypeExtension class from the Microsoft sample code (I took it as is); it provides the GetQueryElementType extension used in this code. And since I had all that, I also took the EFParameterizedExpressionVisitor (ParameterizeExpressionVisitor in my code) and call into it to get the benefit of EF parameterizing my OData calls. All of the code above is setup so that we can now intercept the expression trees. Now we get into the meat of the issue.
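My understanding of why parameterization matters: EF turns a value read through a member on a constant object (what a captured variable compiles to) into a SQL parameter, but inlines a literal ConstantExpression into the SQL text, and OData query options such as $filter arrive as literals, so every distinct value gets its own SQL statement. A sketch of that idea as an ExpressionVisitor; it is not the sample’s EFParameterizedExpressionVisitor, and ValueHolder and ConstantToParameterVisitor are hypothetical names. It runs over LINQ to Objects, where it only shows that the rewritten query still returns the same rows:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical holder: reading a value through a field on a constant object
// is the shape EF treats like a captured variable.
public class ValueHolder<T>
{
    public readonly T Value;

    public ValueHolder(T value)
    {
        Value = value;
    }
}

public class ConstantToParameterVisitor : ExpressionVisitor
{
    protected override Expression VisitConstant(ConstantExpression node)
    {
        // Only lift non-null scalar literals; query roots (the ObjectQuery or
        // EnumerableQuery constant) and nulls are left alone.
        bool isScalar = node.Type.IsPrimitive || node.Type == typeof(string)
            || node.Type == typeof(decimal) || node.Type == typeof(DateTime);
        if (node.Value == null || !isScalar)
        {
            return node;
        }

        // Replace the literal with holder.Value, which has the same type.
        Type holderType = typeof(ValueHolder<>).MakeGenericType(node.Type);
        object holder = Activator.CreateInstance(holderType, node.Value);
        return Expression.Field(Expression.Constant(holder), "Value");
    }
}
```

System.Linq is only there for the usage check: build a query with AsQueryable, run the visitor over query.Expression, and hand the result back to the same provider’s CreateQuery.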

To project, I want to do something like this:

context.Customers.Select(c=> new Customer()
   {
       CustomerId = c.CustomerId,
       Address = c.Address,
       City = c.City,
       Company = c.Company,
       Contact = c.Contact,
       Phone = c.Phone,
       State = c.State
   });

But Entity Framework will not let you project onto an entity type. The solution I found is that it will let you project onto a subclass of the entity:

public class CustomersSub : Customer
{
}

// EF allows projecting onto the subclass
context.Customers.Select(c=> new CustomersSub()
   {
       CustomerId = c.CustomerId,
       Address = c.Address,
       City = c.City,
       Company = c.Company,
       Contact = c.Contact,
       Phone = c.Phone,
       State = c.State
   });

I also don’t want to create a bunch of empty subclasses by hand, or even have a T4 template do it. Instead, I’m going to use TypeBuilder to create them at runtime and then cache the result.
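The runtime-subclass step on its own, as a standalone sketch: AssemblyBuilder.DefineDynamicAssembly (which also works on .NET Core, where AppDomain.DefineDynamicAssembly is not available) plus a cache keyed by the entity type. SubclassFactory and the Customer stand-in are hypothetical. Unlike the code that follows, this version caches Lazy<Type>, because ConcurrentDictionary.GetOrAdd can run its factory more than once under a race, and ModuleBuilder.DefineType throws if the same type name is defined twice.

```csharp
using System;
using System.Collections.Concurrent;
using System.Reflection;
using System.Reflection.Emit;

// Hypothetical stand-in entity for the sketch.
public class Customer
{
    public string CustomerId { get; set; }
}

public static class SubclassFactory
{
    private static readonly ModuleBuilder Module = AssemblyBuilder
        .DefineDynamicAssembly(new AssemblyName("SubclassSketchAssembly"), AssemblyBuilderAccess.Run)
        .DefineDynamicModule("SubclassSketchModule");

    // Lazy<Type> ensures each subclass is defined once, even if two threads
    // ask for the same base type at the same moment.
    private static readonly ConcurrentDictionary<Type, Lazy<Type>> Cache =
        new ConcurrentDictionary<Type, Lazy<Type>>();

    public static Type GetSubclass(Type baseType)
    {
        return Cache.GetOrAdd(baseType, t => new Lazy<Type>(() => CreateSubclass(t))).Value;
    }

    private static Type CreateSubclass(Type baseType)
    {
        // An empty public class deriving from baseType. No constructor is defined,
        // so CreateType supplies a parameterless one that calls the base constructor.
        TypeBuilder builder = Module.DefineType(
            baseType.FullName + "SubClass",
            TypeAttributes.Public | TypeAttributes.Class,
            baseType);
        return builder.CreateType();
    }
}
```

The base type has to be public and unsealed with an accessible parameterless constructor, which generated EF entity classes normally are.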

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;

public class SecurityProjection : ExpressionVisitor
{
    public static readonly ConcurrentDictionary<Type, Type> SubClasses = new ConcurrentDictionary<Type, Type>();
    private static readonly ModuleBuilder _moduleBuilder;

    // Change these to match your own rules for when properties are removed
    protected static readonly IDictionary<Type, IList<string>> removeProperties = new Dictionary<Type, IList<string>>();

    static SecurityProjection()
    {
        // Will need these to create subclasses on the fly
        var assemblyName = new AssemblyName("SecurityProjectionAssembly");
        var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run);
        _moduleBuilder = assemblyBuilder.DefineDynamicModule("SecurityProjectionModule");

        // Always remove CreditLimit and Balance from Customer, and SalesPerson from Invoice
        removeProperties.Add(typeof(Customer), new List<string>());
        removeProperties[typeof(Customer)].Add("CreditLimit");
        removeProperties[typeof(Customer)].Add("Balance");
        removeProperties.Add(typeof(Invoice), new List<string>());
        removeProperties[typeof(Invoice)].Add("SalesPerson");

    }

    public Expression CheckSecurity(Expression expression)
    {
        expression = Visit(expression);
        return expression;
    }

    protected override MemberBinding VisitMemberBinding(MemberBinding node)
    {
        var memberAssignment = node as MemberAssignment;
        if (memberAssignment != null)
        {
            var memType = memberAssignment.Expression.Type;
            if (memType != null)
            {
                if (removeProperties.ContainsKey(memType))
                {
                    // Make a subclass for the projection, because EF will not let you project onto an entity type
                    var to = SubClasses.GetOrAdd(memType, CreateSubClass);

                    var projection = Project(memType, to, memberAssignment.Expression);
                    var binder = Expression.Bind(memberAssignment.Member, projection);
                    return binder;
                }
                memType = memType.GetQueryElementType();
                if (memType != null && removeProperties.ContainsKey(memType))
                {
                    // Make a subclass for the projection, because EF will not let you project onto an entity type
                    var to = SubClasses.GetOrAdd(memType, CreateSubClass);

                    // parameter of the expression
                    var source = Expression.Parameter(memType, "source");
                    var projection = Project(memType, to, source);

                    var func = Expression.Lambda(typeof(Func<,>).MakeGenericType(memType, to), projection, source);

                    var result = Expression.Call(typeof(Enumerable), "Select", new[] { memType, to }, memberAssignment.Expression, func);
                    var binder = Expression.Bind(memberAssignment.Member, result);
                    return binder;
                }
            }
        }

        return base.VisitMemberBinding(node);
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        var methodType = CheckIQueryable(node.Method.ReturnType);
        if (methodType != null && removeProperties.ContainsKey(methodType))
        {
            // Make a subclass for the projection, because EF will not let you project onto an entity type
            var to = SubClasses.GetOrAdd(methodType, CreateSubClass);
            // parameter of the expression
            var source = Expression.Parameter(methodType, "source");
            var projection = Project(methodType, to, source);

            var func = Expression.Lambda(typeof (Func<,>).MakeGenericType(methodType, to), projection, source);

            var result = Expression.Call(typeof (Queryable), "Select", new[] {methodType, to}, node, func);
            return result;
        }

        return base.VisitMethodCall(node);
    }

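    private Type CheckIQueryable(Type type)
    {
        // GetInterfaces() does not include the type itself, so check it too
        // (Queryable.Where, Select, Take, etc. return IQueryable<T> directly)
        return
            new[] { type }.Concat(type.GetInterfaces())
                .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof (IQueryable<>))
                .Select(i => i.GetGenericArguments().First())
                .FirstOrDefault();
    }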

    private Type CreateSubClass(Type type)
    {
        var subclass = _moduleBuilder.DefineType(type.Name + "SubClass", type.Attributes, type);
        var newType = subclass.CreateType();
        return newType;
    }

    private Expression Project(Type from, Type to, Expression pSource)
    {
        var ignoreProps = removeProperties[from];

        var bindings = GetProperties(from, false)
                    .Join(GetProperties(to, true), source => new {source.Name, source.PropertyType},
                            dest => new {dest.Name, dest.PropertyType}, (source, dest) => new {source, dest})
                    .Where(a => !ignoreProps.Contains(a.source.Name))
                    .Select(prop => Expression.Bind(prop.dest, Expression.Property(pSource, prop.source)))
                    .ToList();
        return Expression.MemberInit(Expression.New(to), bindings);
    }

    private IEnumerable<PropertyInfo> GetProperties(Type type, bool write)
    {
        var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance).AsEnumerable();
        if (write)
        {
            properties = properties.Where(p => p.CanWrite);
        }
        else
        {
            properties = properties.Where(p => p.CanRead);
        }

        return properties;
    }
}

Some things to explain. I created a static IDictionary<Type, IList<string>> to hold the names of the properties I want to remove from the projection for each type. I have it hard coded to always remove CreditLimit and Balance from Customer, and SalesPerson from Invoice. In VisitMethodCall we look for the root query and add the projection, similar to how I wrote it out by hand above. Then VisitMemberBinding looks for any member whose type is either one of the types we are looking for or an IEnumerable of one of them. If it’s just the type, we project it into a new instance of the subclass in its place. For IEnumerables we do a Select to project each item into our subclass. This is why the WCFEFProvider GetResourceType method needs to create an empty instance of the original entity to pass down: otherwise WCF will look for a resource type matching the subclass rather than the original entity.
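The core of Project() is an ordinary expression-tree MemberInit, so it can be tried outside EF. A sketch against LINQ to Objects with a hypothetical Customer stand-in and a hypothetical ProjectionSketch helper: because LINQ to Objects has no entity restriction, it projects into the same type instead of a subclass, and it filters on CanRead and CanWrite on one type rather than joining two property lists by name and type.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical stand-in entity for the sketch.
public class Customer
{
    public string CustomerId { get; set; }
    public string Company { get; set; }
    public decimal? CreditLimit { get; set; }
    public decimal? Balance { get; set; }
}

public static class ProjectionSketch
{
    // Builds source => new T { P1 = source.P1, ... } for every public read/write
    // property whose name is not in 'removed'; removed properties keep their defaults.
    public static Expression<Func<T, T>> CopyWithout<T>(ICollection<string> removed) where T : new()
    {
        ParameterExpression source = Expression.Parameter(typeof(T), "source");
        List<MemberBinding> bindings = typeof(T)
            .GetProperties(BindingFlags.Public | BindingFlags.Instance)
            .Where(p => p.CanRead && p.CanWrite && !removed.Contains(p.Name))
            .Select(p => (MemberBinding)Expression.Bind(p, Expression.Property(source, p)))
            .ToList();
        Expression body = Expression.MemberInit(Expression.New(typeof(T)), bindings);
        return Expression.Lambda<Func<T, T>>(body, source);
    }
}
```

Passing the resulting lambda to Queryable.Select is the same shape of call the visitor builds with Expression.Call(typeof(Queryable), "Select", ...).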

It would probably be worth caching the properties of each type instead of using reflection each time.
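A minimal sketch of such a cache, keyed by the same (type, write) pair that GetProperties filters on. PropertyCache and the Sample type are hypothetical; the cached arrays are shared, so callers should treat them as read-only.

```csharp
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Reflection;

public static class PropertyCache
{
    // One entry per (type, write) pair; reflection runs only on a cache miss.
    private static readonly ConcurrentDictionary<Tuple<Type, bool>, PropertyInfo[]> Cache =
        new ConcurrentDictionary<Tuple<Type, bool>, PropertyInfo[]>();

    public static PropertyInfo[] GetProperties(Type type, bool write)
    {
        return Cache.GetOrAdd(Tuple.Create(type, write), key =>
            key.Item1.GetProperties(BindingFlags.Public | BindingFlags.Instance)
                .Where(p => key.Item2 ? p.CanWrite : p.CanRead)
                .ToArray());
    }
}

// Hypothetical example type: one read/write property and one read-only property.
public class Sample
{
    public int Id { get; set; }
    public string Label { get { return "fixed"; } }
}
```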

Now when you hit the web service, Invoices will still return SalesPerson, but it will be null, and the same goes for Balance and CreditLimit on Customer. I’ve tested this with projections ($select) and $expand, and it all seems to be working, but I can’t say it’s 100% bulletproof. 🙂


Friday, April 25th, 2014 OData