2929import static org .junit .Assert .fail ;
3030import static org .mockito .AdditionalAnswers .returnsFirstArg ;
3131import static org .mockito .Matchers .any ;
32- import static org .mockito .Matchers .anySet ;
32+ import static org .mockito .Matchers .anySetOf ;
3333import static org .mockito .Matchers .anyString ;
3434import static org .mockito .Mockito .atLeastOnce ;
3535import static org .mockito .Mockito .mock ;
5454import org .mitre .oauth2 .repository .OAuth2TokenRepository ;
5555import org .mitre .oauth2 .service .ClientDetailsEntityService ;
5656import org .mitre .oauth2 .service .SystemScopeService ;
57+ import org .mitre .openid .connect .service .ApprovedSiteService ;
5758import org .mockito .InjectMocks ;
5859import org .mockito .Matchers ;
5960import org .mockito .Mock ;
7576 *
7677 */
7778@ RunWith (MockitoJUnitRunner .class )
79+ @ SuppressWarnings ("deprecation" )
7880public class TestDefaultOAuth2ProviderTokenService {
7981
8082 // Grace period for time-sensitive tests.
8183 private static final long DELTA = 100L ;
8284
85+ private static final String clientId = "test_client" ;
86+ private static final String badClientId = "bad_client" ;
87+ private static final Set <String > scope =
88+ newHashSet ("openid" , "profile" , "email" , "offline_access" );
89+
8390 // Test Fixture:
84- private OAuth2Authentication authentication ;
8591 private ClientDetailsEntity client ;
8692 private ClientDetailsEntity badClient ;
87- private String clientId = "test_client" ;
88- private String badClientId = "bad_client" ;
89- private Set <String > scope = newHashSet ("openid" , "profile" , "email" , "offline_access" );
9093 private OAuth2RefreshTokenEntity refreshToken ;
9194 private OAuth2AccessTokenEntity accessToken ;
9295 private String refreshTokenValue = "refresh_token_value" ;
@@ -99,6 +102,9 @@ public class TestDefaultOAuth2ProviderTokenService {
99102 private AuthenticationHolderEntity storedAuthHolder ;
100103 private Set <String > storedScope ;
101104
105+ @ Mock
106+ private OAuth2Authentication authentication ;
107+
102108 @ Mock
103109 private OAuth2TokenRepository tokenRepository ;
104110
@@ -114,6 +120,9 @@ public class TestDefaultOAuth2ProviderTokenService {
114120 @ Mock
115121 private SystemScopeService scopeService ;
116122
123+ @ Mock
124+ private ApprovedSiteService approvedSiteService ;
125+
117126 @ InjectMocks
118127 private DefaultOAuth2ProviderTokenService service ;
119128
@@ -122,9 +131,10 @@ public class TestDefaultOAuth2ProviderTokenService {
122131 */
123132 @ Before
124133 public void prepare () {
125- reset (tokenRepository , authenticationHolderRepository , clientDetailsService , tokenEnhancer );
126134
127- authentication = Mockito .mock (OAuth2Authentication .class );
135+ reset (tokenRepository , authenticationHolderRepository , clientDetailsService , tokenEnhancer ,
136+ scopeService , approvedSiteService , authentication );
137+
128138 OAuth2Request clientAuth =
129139 new OAuth2Request (null , clientId , null , true , scope , null , null , null , null );
130140 when (authentication .getOAuth2Request ()).thenReturn (clientAuth );
@@ -165,21 +175,24 @@ public void prepare() {
165175 when (authenticationHolderRepository .save (any (AuthenticationHolderEntity .class )))
166176 .thenReturn (storedAuthHolder );
167177
168- when (scopeService .fromStrings (anySet ())).thenAnswer (new Answer <Set <SystemScope >>() {
169- @ Override
170- public Set <SystemScope > answer (InvocationOnMock invocation ) throws Throwable {
171- Object [] args = invocation .getArguments ();
172- Set <String > input = (Set <String >) args [0 ];
173- Set <SystemScope > output = new HashSet <>();
174- for (String scope : input ) {
175- output .add (new SystemScope (scope ));
178+ when (scopeService .fromStrings (anySetOf (String .class )))
179+ .thenAnswer (new Answer <Set <SystemScope >>() {
180+ @ Override
181+ @ SuppressWarnings ("unchecked" )
182+ public Set <SystemScope > answer (InvocationOnMock invocation ) throws Throwable {
183+ Object [] args = invocation .getArguments ();
184+ Set <String > input = (Set <String >) args [0 ];
185+ Set <SystemScope > output = new HashSet <>();
186+ for (String scope : input ) {
187+ output .add (new SystemScope (scope ));
188+ }
189+ return output ;
176190 }
177- return output ;
178- }
179- });
191+ });
180192
181- when (scopeService .toStrings (anySet ( ))).thenAnswer (new Answer <Set <String >>() {
193+ when (scopeService .toStrings (anySetOf ( SystemScope . class ))).thenAnswer (new Answer <Set <String >>() {
182194 @ Override
195+ @ SuppressWarnings ("unchecked" )
183196 public Set <String > answer (InvocationOnMock invocation ) throws Throwable {
184197 Object [] args = invocation .getArguments ();
185198 Set <SystemScope > input = (Set <SystemScope >) args [0 ];
@@ -191,19 +204,22 @@ public Set<String> answer(InvocationOnMock invocation) throws Throwable {
191204 }
192205 });
193206
194- when (scopeService .scopesMatch (anySet (), anySet ())).thenAnswer (new Answer <Boolean >() {
195- @ Override
196- public Boolean answer (InvocationOnMock invocation ) throws Throwable {
197- Object [] args = invocation .getArguments ();
198- Set <String > expected = (Set <String >) args [0 ];
199- Set <String > actual = (Set <String >) args [1 ];
200- return expected .containsAll (actual );
201- }
202- });
203-
207+ when (scopeService .scopesMatch (anySetOf (String .class ), anySetOf (String .class )))
208+ .thenAnswer (new Answer <Boolean >() {
209+ @ Override
210+ @ SuppressWarnings ("unchecked" )
211+ public Boolean answer (InvocationOnMock invocation ) throws Throwable {
212+ Object [] args = invocation .getArguments ();
213+ Set <String > expected = (Set <String >) args [0 ];
214+ Set <String > actual = (Set <String >) args [1 ];
215+ return expected .containsAll (actual );
216+ }
217+ });
218+
204219 // we're not testing restricted or reserved scopes here, just pass through
205- when (scopeService .removeReservedScopes (anySet ())).then (returnsFirstArg ());
206- when (scopeService .removeRestrictedAndReservedScopes (anySet ())).then (returnsFirstArg ());
220+ when (scopeService .removeReservedScopes (anySetOf (SystemScope .class ))).then (returnsFirstArg ());
221+ when (scopeService .removeRestrictedAndReservedScopes (anySetOf (SystemScope .class )))
222+ .then (returnsFirstArg ());
207223
208224 when (tokenEnhancer .enhance (any (OAuth2AccessTokenEntity .class ), any (OAuth2Authentication .class )))
209225 .thenAnswer (new Answer <OAuth2AccessTokenEntity >() {
@@ -281,7 +297,7 @@ public void createAccessToken_noRefresh() {
281297 verify (authenticationHolderRepository ).save (any (AuthenticationHolderEntity .class ));
282298 verify (tokenEnhancer ).enhance (any (OAuth2AccessTokenEntity .class ), Matchers .eq (authentication ));
283299 verify (tokenRepository ).saveAccessToken (any (OAuth2AccessTokenEntity .class ));
284- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
300+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
285301
286302 verify (tokenRepository , Mockito .never ()).saveRefreshToken (any (OAuth2RefreshTokenEntity .class ));
287303
@@ -303,7 +319,7 @@ public void createAccessToken_yesRefresh() {
303319 // Note: a refactor may be appropriate to only save refresh tokens once to the repository during
304320 // creation.
305321 verify (tokenRepository , atLeastOnce ()).saveRefreshToken (any (OAuth2RefreshTokenEntity .class ));
306- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
322+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
307323
308324 assertThat (token .getRefreshToken (), is (notNullValue ()));
309325 }
@@ -330,7 +346,7 @@ public void createAccessToken_expiration() {
330346 Date lowerBoundRefreshTokens = new Date (start + (refreshTokenValiditySeconds * 1000L ) - DELTA );
331347 Date upperBoundRefreshTokens = new Date (end + (refreshTokenValiditySeconds * 1000L ) + DELTA );
332348
333- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
349+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
334350
335351 assertTrue (token .getExpiration ().after (lowerBoundAccessTokens )
336352 && token .getExpiration ().before (upperBoundAccessTokens ));
@@ -342,7 +358,7 @@ public void createAccessToken_expiration() {
342358 public void createAccessToken_checkClient () {
343359 OAuth2AccessTokenEntity token = service .createAccessToken (authentication );
344360
345- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
361+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
346362
347363 assertThat (token .getClient ().getClientId (), equalTo (clientId ));
348364 }
@@ -351,7 +367,7 @@ public void createAccessToken_checkClient() {
351367 public void createAccessToken_checkScopes () {
352368 OAuth2AccessTokenEntity token = service .createAccessToken (authentication );
353369
354- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
370+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
355371
356372 assertThat (token .getScope (), equalTo (scope ));
357373 }
@@ -368,7 +384,7 @@ public void createAccessToken_checkAttachedAuthentication() {
368384
369385 assertThat (token .getAuthenticationHolder ().getAuthentication (), equalTo (authentication ));
370386 verify (authenticationHolderRepository ).save (any (AuthenticationHolderEntity .class ));
371- verify (scopeService , atLeastOnce ()).removeReservedScopes (anySet ( ));
387+ verify (scopeService , atLeastOnce ()).removeReservedScopes (anySetOf ( SystemScope . class ));
372388 }
373389
374390 @ Test (expected = InvalidTokenException .class )
0 commit comments