Skip to content

fix(#343): reload relationships from database if included during POST #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

namespace JsonApiDotNetCore.Data
{
/// <inheritdoc />
public class DefaultEntityRepository<TEntity>
: DefaultEntityRepository<TEntity, int>,
IEntityRepository<TEntity>
Expand All @@ -26,6 +27,10 @@ public DefaultEntityRepository(
{ }
}

/// <summary>
/// Provides a default repository implementation and is responsible for
/// abstracting any EF Core APIs away from the service layer.
/// </summary>
public class DefaultEntityRepository<TEntity, TId>
: IEntityRepository<TEntity, TId>
where TEntity : class, IIdentifiable<TId>
Expand All @@ -48,7 +53,7 @@ public DefaultEntityRepository(
_genericProcessorFactory = _jsonApiContext.GenericProcessorFactory;
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual IQueryable<TEntity> Get()
{
if (_jsonApiContext.QuerySet?.Fields != null && _jsonApiContext.QuerySet.Fields.Count > 0)
Expand All @@ -57,41 +62,43 @@ public virtual IQueryable<TEntity> Get()
return _dbSet;
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual IQueryable<TEntity> Filter(IQueryable<TEntity> entities, FilterQuery filterQuery)
{
return entities.Filter(_jsonApiContext, filterQuery);
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual IQueryable<TEntity> Sort(IQueryable<TEntity> entities, List<SortQuery> sortQueries)
{
return entities.Sort(sortQueries);
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual async Task<TEntity> GetAsync(TId id)
{
return await Get().SingleOrDefaultAsync(e => e.Id.Equals(id));
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual async Task<TEntity> GetAndIncludeAsync(TId id, string relationshipName)
{
_logger.LogDebug($"[JADN] GetAndIncludeAsync({id}, {relationshipName})");

var result = await Include(Get(), relationshipName).SingleOrDefaultAsync(e => e.Id.Equals(id));
var includedSet = await IncludeAsync(Get(), relationshipName);
var result = await includedSet.SingleOrDefaultAsync(e => e.Id.Equals(id));

return result;
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual async Task<TEntity> CreateAsync(TEntity entity)
{
AttachRelationships();
_dbSet.Add(entity);

await _context.SaveChangesAsync();

return entity;
}

Expand Down Expand Up @@ -129,7 +136,7 @@ private void AttachHasOnePointers()
_context.Entry(relationship.Value).State = EntityState.Unchanged;
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual async Task<TEntity> UpdateAsync(TId id, TEntity entity)
{
var oldEntity = await GetAsync(id);
Expand All @@ -148,14 +155,14 @@ public virtual async Task<TEntity> UpdateAsync(TId id, TEntity entity)
return oldEntity;
}

/// </ inheritdoc>
/// <inheritdoc />
public async Task UpdateRelationshipsAsync(object parent, RelationshipAttribute relationship, IEnumerable<string> relationshipIds)
{
var genericProcessor = _genericProcessorFactory.GetProcessor<IGenericProcessor>(typeof(GenericProcessor<>), relationship.Type);
await genericProcessor.UpdateRelationshipsAsync(parent, relationship, relationshipIds);
}

/// </ inheritdoc>
/// <inheritdoc />
public virtual async Task<bool> DeleteAsync(TId id)
{
var entity = await GetAsync(id);
Expand All @@ -170,7 +177,8 @@ public virtual async Task<bool> DeleteAsync(TId id)
return true;
}

/// </ inheritdoc>
/// <inheritdoc />
[Obsolete("Use IncludeAsync")]
public virtual IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName)
{
var entity = _jsonApiContext.RequestEntity;
Expand All @@ -185,10 +193,57 @@ public virtual IQueryable<TEntity> Include(IQueryable<TEntity> entities, string
{
throw new JsonApiException(400, $"Including the relationship {relationshipName} on {entity.EntityName} is not allowed");
}

return entities.Include(relationship.InternalRelationshipName);
}

/// <inheritdoc />
public virtual async Task<IQueryable<TEntity>> IncludeAsync(IQueryable<TEntity> entities, string relationshipName)
{
var entity = _jsonApiContext.RequestEntity;
var relationship = entity.Relationships.FirstOrDefault(r => r.PublicRelationshipName == relationshipName);
if (relationship == null)
{
throw new JsonApiException(400, $"Invalid relationship {relationshipName} on {entity.EntityName}",
$"{entity.EntityName} does not have a relationship named {relationshipName}");
}

if (!relationship.CanInclude)
{
throw new JsonApiException(400, $"Including the relationship {relationshipName} on {entity.EntityName} is not allowed");
}

await ReloadPointerAsync(relationship);

return entities.Include(relationship.InternalRelationshipName);
}

/// </ inheritdoc>
/// <summary>
/// Ensure relationships on the provided entity have been fully loaded from the database.
/// </summary>
/// <remarks>
/// The only known case when this should be called is when a POST request is
/// sent with an ?include query.
///
/// See https://github.com/json-api-dotnet/JsonApiDotNetCore/issues/343
/// </remarks>
private async Task ReloadPointerAsync(RelationshipAttribute relationshipAttr)
{
if (relationshipAttr.IsHasOne && _jsonApiContext.HasOneRelationshipPointers.Get().TryGetValue(relationshipAttr, out var pointer))
{
await _context.Entry(pointer).ReloadAsync();
}

if (relationshipAttr.IsHasMany && _jsonApiContext.HasManyRelationshipPointers.Get().TryGetValue(relationshipAttr, out var pointers))
{
foreach (var hasManyPointer in pointers)
{
await _context.Entry(hasManyPointer).ReloadAsync();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is bad ☹️ ... need to reload the entire navigation property

Copy link
Contributor

@milosloub milosloub Aug 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First of all. Thank you, for this solution.
For HasMany relationships, what about something like:
_context.Entry(entity).Collection(relation).Load();

instead of:
_context.Entry(hasManyPointer).ReloadAsync();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I wasn't aware of the .Collection api. I think I like the approach I ended up taking better. The reason is that we should end up with a single SQL request that uses JOIN to load all the relationships rather than reloading each included relationship individually. I'll review it again later tonight to verify the difference.

}
}
}

/// <inheritdoc />
public virtual async Task<IEnumerable<TEntity>> PageAsync(IQueryable<TEntity> entities, int pageSize, int pageNumber)
{
if (pageNumber >= 0)
Expand All @@ -209,23 +264,23 @@ public virtual async Task<IEnumerable<TEntity>> PageAsync(IQueryable<TEntity> en
.ToListAsync();
}

/// </ inheritdoc>
/// <inheritdoc />
public async Task<int> CountAsync(IQueryable<TEntity> entities)
{
return (entities is IAsyncEnumerable<TEntity>)
? await entities.CountAsync()
: entities.Count();
}

/// </ inheritdoc>
/// <inheritdoc />
public async Task<TEntity> FirstOrDefaultAsync(IQueryable<TEntity> entities)
{
return (entities is IAsyncEnumerable<TEntity>)
? await entities.FirstOrDefaultAsync()
: entities.FirstOrDefault();
}

/// </ inheritdoc>
/// <inheritdoc />
public async Task<IReadOnlyList<TEntity>> ToListAsync(IQueryable<TEntity> entities)
{
return (entities is IAsyncEnumerable<TEntity>)
Expand Down
6 changes: 5 additions & 1 deletion src/JsonApiDotNetCore/Data/IEntityReadRepository.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
Expand All @@ -20,6 +21,9 @@ public interface IEntityReadRepository<TEntity, in TId>
/// </summary>
IQueryable<TEntity> Get();

[Obsolete("Use IncludeAsync")]
IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName);

/// <summary>
/// Include a relationship in the query
/// </summary>
Expand All @@ -28,7 +32,7 @@ public interface IEntityReadRepository<TEntity, in TId>
/// _todoItemsRepository.GetAndIncludeAsync(1, "achieved-date");
/// </code>
/// </example>
IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName);
Task<IQueryable<TEntity>> IncludeAsync(IQueryable<TEntity> entities, string relationshipName);

/// <summary>
/// Apply a filter to the provided queryable
Expand Down
27 changes: 22 additions & 5 deletions src/JsonApiDotNetCore/Services/EntityResourceService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ public virtual async Task<TResource> CreateAsync(TResource resource)

entity = await _entities.CreateAsync(entity);

// this ensures relationships get reloaded from the database if they have
// been requested
// https://github.com/json-api-dotnet/JsonApiDotNetCore/issues/343
if (ShouldIncludeRelationships())
return await GetWithRelationshipsAsync(entity.Id);

return MapOut(entity);
}

Expand All @@ -92,7 +98,7 @@ public virtual async Task<IEnumerable<TResource>> GetAsync()
entities = ApplySortAndFilterQuery(entities);

if (ShouldIncludeRelationships())
entities = IncludeRelationships(entities, _jsonApiContext.QuerySet.IncludedRelationships);
entities = await IncludeRelationshipsAsync(entities, _jsonApiContext.QuerySet.IncludedRelationships);

if (_jsonApiContext.Options.IncludeTotalRecordCount)
_jsonApiContext.PageManager.TotalRecords = await _entities.CountAsync(entities);
Expand Down Expand Up @@ -218,7 +224,8 @@ protected virtual IQueryable<TEntity> ApplySortAndFilterQuery(IQueryable<TEntity
return entities;
}

protected virtual IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> entities, List<string> relationships)
[Obsolete("Use IncludeRelationshipsAsync")]
protected IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> entities, List<string> relationships)
{
_jsonApiContext.IncludedRelationships = relationships;

Expand All @@ -228,14 +235,24 @@ protected virtual IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> e
return entities;
}

protected virtual async Task<IQueryable<TEntity>> IncludeRelationshipsAsync(IQueryable<TEntity> entities, List<string> relationships)
{
_jsonApiContext.IncludedRelationships = relationships;

foreach (var r in relationships)
entities = await _entities.IncludeAsync(entities, r);

return entities;
}

private async Task<TResource> GetWithRelationshipsAsync(TId id)
{
var query = _entities.Get().Where(e => e.Id.Equals(id));

_jsonApiContext.QuerySet.IncludedRelationships.ForEach(r =>
foreach (var r in _jsonApiContext.QuerySet.IncludedRelationships)
{
query = _entities.Include(query, r);
});
query = await _entities.IncludeAsync(query, r);
}

var value = await _entities.FirstOrDefaultAsync(query);

Expand Down
Loading