Skip to content

Commit fe5c710

Browse files
authored
add ClaimsPrincipal User to filters
1 parent a535067 commit fe5c710

12 files changed

Lines changed: 48 additions & 45 deletions

src/Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<Project>
33
<PropertyGroup>
44
<NoWarn>CS1591;NU5104;CS1573</NoWarn>
5-
<Version>18.1.0</Version>
5+
<Version>19.0.0</Version>
66
<AssemblyVersion>1.0.0</AssemblyVersion>
77
<PackageTags>EntityFrameworkCore, EntityFramework, GraphQL</PackageTags>
88
<SignAssembly>false</SignAssembly>

src/GraphQL.EntityFramework/ConnectionConverter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ static async Task<Connection<TItem>> Range<TSource, TItem>(
172172
var page = list.Skip(skip).Take(take);
173173
QueryLogger.Write(page);
174174
IEnumerable<TItem> result = await page.ToListAsync(cancellation);
175-
result = await filters.ApplyFilter(result, context.UserContext);
175+
result = await filters.ApplyFilter(result, context.UserContext, context.User);
176176

177177
cancellation.ThrowIfCancellationRequested();
178178
return Build(skip, take, count, result);

src/GraphQL.EntityFramework/Filters/Filters.cs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1-
namespace GraphQL.EntityFramework;
1+
using System.Security.Claims;
2+
3+
namespace GraphQL.EntityFramework;
24

35
#region FiltersSignature
46

57
public class Filters
68
{
7-
public delegate bool Filter<in TEntity>(object userContext, TEntity input)
9+
public delegate bool Filter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
810
where TEntity : class;
911

10-
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, TEntity input)
12+
public delegate Task<bool> AsyncFilter<in TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity input)
1113
where TEntity : class;
1214

1315
#endregion
1416

1517
public void Add<TEntity>(Filter<TEntity> filter)
1618
where TEntity : class =>
1719
funcs[typeof(TEntity)] =
18-
(context, item) =>
20+
(userContext, userPrincipal, item) =>
1921
{
2022
try
2123
{
22-
return Task.FromResult(filter(context, (TEntity) item));
24+
return Task.FromResult(filter(userContext, userPrincipal, (TEntity) item));
2325
}
2426
catch (Exception exception)
2527
{
@@ -30,23 +32,23 @@ public void Add<TEntity>(Filter<TEntity> filter)
3032
public void Add<TEntity>(AsyncFilter<TEntity> filter)
3133
where TEntity : class =>
3234
funcs[typeof(TEntity)] =
33-
async (context, item) =>
35+
async (userContext, userPrincipal, item) =>
3436
{
3537
try
3638
{
37-
return await filter(context, (TEntity) item);
39+
return await filter(userContext, userPrincipal, (TEntity) item);
3840
}
3941
catch (Exception exception)
4042
{
4143
throw new($"Failed to execute filter. {nameof(TEntity)}: {typeof(TEntity)}.", exception);
4244
}
4345
};
4446

45-
delegate Task<bool> Filter(object userContext, object input);
47+
delegate Task<bool> Filter(object userContext, ClaimsPrincipal? userPrincipal, object input);
4648

4749
Dictionary<Type, Filter> funcs = new();
4850

49-
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext)
51+
internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal)
5052
where TEntity : class
5153
{
5254
if (funcs.Count == 0)
@@ -63,7 +65,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
6365
var list = new List<TEntity>();
6466
foreach (var item in result)
6567
{
66-
if (await ShouldInclude(userContext, item, filters))
68+
if (await ShouldInclude(userContext, userPrincipal, item, filters))
6769
{
6870
list.Add(item);
6971
}
@@ -72,12 +74,12 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerab
7274
return list;
7375
}
7476

75-
static async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity item, List<AsyncFilter<TEntity>> filters)
77+
static async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity item, List<AsyncFilter<TEntity>> filters)
7678
where TEntity : class
7779
{
7880
foreach (var func in filters)
7981
{
80-
if (!await func(userContext, item))
82+
if (!await func(userContext, userPrincipal, item))
8183
{
8284
return false;
8385
}
@@ -86,7 +88,7 @@ static async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity item,
8688
return true;
8789
}
8890

89-
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TEntity? item)
91+
internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
9092
where TEntity : class
9193
{
9294
if (item is null)
@@ -101,7 +103,7 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(object userContext, TEn
101103

102104
foreach (var func in FindFilters<TEntity>())
103105
{
104-
if (!await func(userContext, item))
106+
if (!await func(userContext, userPrincipal, item))
105107
{
106108
return false;
107109
}
@@ -116,7 +118,7 @@ IEnumerable<AsyncFilter<TEntity>> FindFilters<TEntity>()
116118
var type = typeof(TEntity);
117119
foreach (var pair in funcs.Where(x => x.Key.IsAssignableFrom(type)))
118120
{
119-
yield return (context, item) => pair.Value(context, item);
121+
yield return (userContext, userPrincipal, item) => pair.Value(userContext, userPrincipal, item);
120122
}
121123
}
122124
}
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
using GraphQL.EntityFramework;
1+
using System.Security.Claims;
2+
using GraphQL.EntityFramework;
23

34
class NullFilters :
45
Filters
56
{
67
public static NullFilters Instance = new();
78

8-
internal override Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext) =>
9+
internal override Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(IEnumerable<TEntity> result, object userContext, ClaimsPrincipal? userPrincipal) =>
910
Task.FromResult(result);
1011

11-
internal override Task<bool> ShouldInclude<TEntity>(object userContext, TEntity? item)
12+
internal override Task<bool> ShouldInclude<TEntity>(object userContext, ClaimsPrincipal? userPrincipal, TEntity? item)
1213
where TEntity : class =>
1314
Task.FromResult(true);
1415
}

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public FieldBuilder<TSource, TReturn> AddNavigationField<TSource, TReturn>(
3535
var fieldContext = BuildContext(context);
3636

3737
var result = resolve(fieldContext);
38-
if (await fieldContext.Filters.ShouldInclude(context.UserContext, result))
38+
if (await fieldContext.Filters.ShouldInclude(context.UserContext, context.User, result))
3939
{
4040
return result;
4141
}

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationConnection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ ConnectionBuilder<TSource> AddEnumerableConnection<TSource, TGraph, TReturn>(
4949
var enumerable = resolve(efFieldContext);
5050

5151
enumerable = enumerable.ApplyGraphQlArguments(hasId, context);
52-
enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext);
52+
enumerable = await efFieldContext.Filters.ApplyFilter(enumerable, context.UserContext, context.User);
5353
var page = enumerable.ToList();
5454

5555
return ConnectionConverter.ApplyConnectionContext(

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_NavigationList.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public FieldBuilder<TSource, TReturn> AddNavigationListField<TSource, TReturn>(
3535
var fieldContext = BuildContext(context);
3636
var result = resolve(fieldContext);
3737
result = result.ApplyGraphQlArguments(hasId, context);
38-
return await fieldContext.Filters.ApplyFilter(result, context.UserContext);
38+
return await fieldContext.Filters.ApplyFilter(result, context.UserContext, context.User);
3939
});
4040
}
4141

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Queryable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ FieldType BuildQueryField<TSource, TReturn>(
7777
.ToListAsync(context.CancellationToken);
7878
}
7979

80-
return await fieldContext.Filters.ApplyFilter(list, context.UserContext);
80+
return await fieldContext.Filters.ApplyFilter(list, context.UserContext, context.User);
8181
});
8282
}
8383

src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Single.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ FieldType BuildSingleField<TSource, TReturn>(
9898

9999
if (single is not null)
100100
{
101-
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, single))
101+
if (await efFieldContext.Filters.ShouldInclude(context.UserContext, context.User, single))
102102
{
103103
if (mutate is not null)
104104
{

src/Snippets/GlobalFilterSnippets.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void Add(ServiceCollection services)
2020

2121
var filters = new Filters();
2222
filters.Add<MyEntity>(
23-
(userContext, item) => item.Property != "Ignore");
23+
(userContext, userPrincipal, item) => item.Property != "Ignore");
2424
EfGraphQLConventions.RegisterInContainer<MyDbContext>(
2525
services,
2626
resolveFilters: x => filters);

0 commit comments

Comments
 (0)