diff --git a/api/src/main/resources/example.application.yaml b/api/src/main/resources/example.application.yaml index 18092bc..b6f3aed 100644 --- a/api/src/main/resources/example.application.yaml +++ b/api/src/main/resources/example.application.yaml @@ -98,6 +98,43 @@ loginsvc: # ldapFieldName: claimFieldName mail: "email" displayname: "displayname" + # MS Entra (Azure AD) Bearer token authentication provider. + # Users with a valid Entra access token can exchange it for a login-service JWT. + #entra: + # Set the order of the protocol starting from 1 + # Set to 0 to disable or simply exclude the entra tag from config + # NOTE: At least 1 auth protocol needs to be enabled + #order: 0 + # Azure AD tenant ID (directory ID) + #tenant-id: "your-tenant-id" + # Application (client) ID registered in Entra + #client-id: "your-client-id" + # Client secret used to call MS Graph API for on-premises username resolution. + # When set, the authenticated user's UPN is exchanged for their lower-case + # samAccountName + # via Graph API. Omit to use the UPN from the token directly. + #client-secret: "your-client-secret" + # Accepted JWT 'aud' claim values — tokens from any listed application are accepted; + # use an empty list to accept any token issued by the configured tenant + #audiences: + #- "api://your-client-id" + #- "other-app-client-id" + # Mapping from on-premises DNS domain names to NetBIOS short names. + # Required when client-secret is set and users have on-premises AD accounts. + # These mapped values are used to allow known domains and log the mapped AB value. + #domains: + #corp.example.com: "CORP" + #another.domain.com: "ANOTHER" + # Base URL for Microsoft login/token endpoints. + # Defaults to the public Azure cloud. Override for sovereign clouds. + #login-base-url: "https://login.microsoftonline.com" + # Base URL for the Microsoft Graph API. + # Defaults to the public Azure cloud. Override for sovereign clouds. + #graph-base-url: "https://graph.microsoft.com" + # Optional mapping from Entra JWT claim names to LS JWT claim names + #attributes: + #preferred_username: "upn" + #email: "email" experimental: # ability to enable experimental endpoints (default=false) diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala index ead04ed..e204ff7 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/SecurityConfig.scala @@ -27,6 +27,7 @@ import org.springframework.security.web.authentication.www.BasicAuthenticationFi import org.springframework.security.web.{AuthenticationEntryPoint, SecurityFilterChain} import za.co.absa.loginsvc.rest.config.provider.AuthConfigProvider import za.co.absa.loginsvc.rest.provider.ad.ldap.LdapConnectionException +import za.co.absa.loginsvc.rest.provider.entra.{MsEntraBearerTokenFilter, MsEntraTokenValidator} import za.co.absa.loginsvc.rest.provider.kerberos.KerberosSPNEGOAuthenticationProvider import javax.servlet.http.{HttpServletRequest, HttpServletResponse} @@ -39,6 +40,8 @@ class SecurityConfig @Autowired()(authConfigsProvider: AuthConfigProvider, authM private val ldapConfig = authConfigsProvider.getLdapConfig.orNull private val isKerberosEnabled = authConfigsProvider.getLdapConfig.exists(_.enableKerberos.isDefined) + private val msEntraConfig = authConfigsProvider.getMsEntraConfig + private val isMsEntraEnabled = msEntraConfig.exists(_.order > 0) @Bean @@ -76,6 +79,11 @@ class SecurityConfig @Autowired()(authConfigsProvider: AuthConfigProvider, authM classOf[BasicAuthenticationFilter]) } + if (isMsEntraEnabled) { + val entraFilter = new MsEntraBearerTokenFilter(MsEntraTokenValidator(msEntraConfig.get)) + http.addFilterBefore(entraFilter, classOf[BasicAuthenticationFilter]) + } + http.build() } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfig.scala new file mode 100644 index 0000000..8c80f85 --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfig.scala @@ -0,0 +1,75 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.config.auth + +import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} +import za.co.absa.loginsvc.rest.config.validation.{ConfigValidatable, ConfigValidationException, ConfigValidationResult} + +/** + * Configuration for MS Entra (Azure AD) Bearer token authentication provider. + * + * @param tenantId Azure AD tenant ID (directory ID) + * @param clientId Application (client) ID registered in Entra + * @param clientSecret Client secret used to acquire a Graph API token for username resolution. + * When set, the token's `preferred_username` (UPN) is exchanged for + * `onPremisesSamAccountName` via MS Graph, and the resulting username + * is formatted as lower-case `samAccountName`. + * @param audiences Accepted JWT 'aud' claim values — tokens from any listed app are accepted; + * empty list accepts any token from the tenant + * @param domains Mapping from on-premises DNS domain names to their NetBIOS short names, + * e.g. `corp.example.com -> CORP`. Required when `clientSecret` is set + * so known domains can be allowed and their mapped AB values logged. + * @param order Provider ordering (0 = disabled, 1+ = active) + * @param attributes Optional mapping from Entra JWT claim names to LS JWT claim names + * @param loginBaseUrl Base URL for Microsoft login/token endpoints. + * Defaults to the public Azure cloud (`https://login.microsoftonline.com`). + * Override for sovereign clouds (e.g. Azure Government). + * @param graphBaseUrl Base URL for the Microsoft Graph API. + * Defaults to the public Azure cloud (`https://graph.microsoft.com`). + * Override for sovereign clouds (e.g. Azure Government). + */ +case class MsEntraConfig( + tenantId: String, + clientId: String, + clientSecret: Option[String] = None, + audiences: List[String], + domains: Option[Map[String, String]] = None, + order: Int, + attributes: Option[Map[String, String]], + loginBaseUrl: String = "https://login.microsoftonline.com", + graphBaseUrl: String = "https://graph.microsoft.com" +) extends ConfigValidatable with ConfigOrdering { + + def throwErrors(): Unit = + this.validate().throwOnErrors() + + override def validate(): ConfigValidationResult = { + if (order > 0) { + val results = Seq( + Option(tenantId) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("tenantId is empty"))), + + Option(clientId) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("clientId is empty"))) + ) + + results.foldLeft[ConfigValidationResult](ConfigValidationSuccess)(ConfigValidationResult.merge) + } else ConfigValidationSuccess + } +} diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/AuthConfigProvider.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/AuthConfigProvider.scala index cc0a067..86f91f4 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/AuthConfigProvider.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/AuthConfigProvider.scala @@ -21,4 +21,5 @@ import za.co.absa.loginsvc.rest.config.auth._ trait AuthConfigProvider { def getLdapConfig : Option[ActiveDirectoryLDAPConfig] def getUsersConfig : Option[UsersConfig] + def getMsEntraConfig : Option[MsEntraConfig] } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProvider.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProvider.scala index aa871a6..71be7bd 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProvider.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProvider.scala @@ -87,6 +87,14 @@ class ConfigProvider(@Value("${spring.config.location}") yamlPath: String) userConfigOption } + def getMsEntraConfig: Option[MsEntraConfig] = { + val entraConfigOption = createConfigClass[MsEntraConfig]("loginsvc.rest.auth.provider.entra") + if (entraConfigOption.nonEmpty) + entraConfigOption.get.throwErrors() + + entraConfigOption + } + private def getGitConfig: GitConfig = { createConfigClass[GitConfig]("loginsvc.rest.config.git-info"). getOrElse(GitConfig(generateGitProperties = false, generateGitPropertiesFile = false)) diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilter.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilter.scala new file mode 100644 index 0000000..49755f0 --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilter.scala @@ -0,0 +1,88 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import org.slf4j.LoggerFactory +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken +import org.springframework.security.core.authority.SimpleGrantedAuthority +import org.springframework.security.core.context.SecurityContextHolder +import org.springframework.web.filter.OncePerRequestFilter +import za.co.absa.loginsvc.model.User + +import javax.servlet.{FilterChain, ServletRequest, ServletResponse} +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} + +/** + * Spring Security filter that intercepts requests carrying an MS Entra Bearer token. + * + * When an `Authorization: Bearer ` header is present and no authentication + * is already established, delegates to [[MsEntraTokenValidator]] to validate the token + * and populate the [[SecurityContextHolder]]. + * + * On invalid tokens the request is rejected with HTTP 401. + * On missing Bearer header the filter passes through, allowing other filters (e.g. + * BasicAuth) to handle authentication. + */ +class MsEntraBearerTokenFilter(validator: MsEntraTokenValidator) extends OncePerRequestFilter { + + private val log = LoggerFactory.getLogger(classOf[MsEntraBearerTokenFilter]) + + private val BearerPrefix = "Bearer " + + override def doFilterInternal( + request: HttpServletRequest, + response: HttpServletResponse, + filterChain: FilterChain + ): Unit = { + val authHeader = Option(request.getHeader("Authorization")) + + authHeader match { + case Some(header) if header.startsWith(BearerPrefix) => + // Only process if SecurityContext is not already populated + if (SecurityContextHolder.getContext.getAuthentication != null) { + filterChain.doFilter(request, response) + } else { + val rawToken = header.substring(BearerPrefix.length).trim + validator.validate(rawToken) match { + case Success(user) => + log.info(s"Entra-based: Login of user ${user.name} - ok") + setAuthentication(user) + filterChain.doFilter(request, response) + + case Failure(ex) => + log.warn(s"Entra Bearer token rejected: ${ex.getMessage}") + SecurityContextHolder.clearContext() + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setContentType("application/json") + response.getWriter.write(s"""{"error": "Invalid or expired Entra token"}""") + } + } + + case _ => + // No Bearer header — pass through to allow other auth mechanisms + filterChain.doFilter(request, response) + } + } + + private def setAuthentication(user: User): Unit = { + val authorities = user.groups.map(new SimpleGrantedAuthority(_)).asJava + val authentication = new UsernamePasswordAuthenticationToken(user, null, authorities) + SecurityContextHolder.getContext.setAuthentication(authentication) + } +} diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClient.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClient.scala new file mode 100644 index 0000000..809fdec --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClient.scala @@ -0,0 +1,183 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.slf4j.LoggerFactory +import za.co.absa.loginsvc.rest.config.auth.MsEntraConfig + +import java.io.{DataOutputStream, InputStream} +import java.net.{HttpURLConnection, URL, URLEncoder} +import java.util.Locale +import java.util.concurrent.TimeUnit +import scala.io.Source +import scala.util.{Failure, Success, Try} + +/** + * Resolves a username via the MS Graph API by looking up the user's on-premises SAM account name. + * + * Returns `Some("samaccountname")` when on-premises AD attributes are present, + * or `None` to signal the caller to fall back to the UPN from the token. + */ +trait GraphUsernameResolver { + def resolveUsername(upn: String): Option[String] +} + +/** + * MS Graph-based implementation of [[GraphUsernameResolver]]. + * + * Acquires an access token for the Graph API via client credentials (clientId + clientSecret), + * then queries `GET /v1.0/users/{upn}?$select=onPremisesSamAccountName,onPremisesDomainName`. + * The DNS domain name must exist in the `domains` config map; its mapped AB/NetBIOS value is + * logged, and the result is returned as lower-case `samAccountName`. + * + * Falls back to `None` (i.e., use UPN) when: + * - the user has no on-premises AD attributes (e.g. cloud-only or external-tenant users) + * - the `onPremisesDomainName` is not in the `domains` map + * - any HTTP or parsing error occurs + * + * The Graph access token is cached for 50 minutes. + * + * @param config Entra config — must have `clientSecret` set; `domains` maps DNS domain → + * AB/NetBIOS short name for allow-listing and logging. + */ +class MsEntraGraphClient( + config: MsEntraConfig +) extends GraphUsernameResolver { + + private val logger = LoggerFactory.getLogger(classOf[MsEntraGraphClient]) + + private val tokenEndpoint = + s"${config.loginBaseUrl}/${config.tenantId}/oauth2/v2.0/token" + + private val graphUsersBaseUrl = s"${config.graphBaseUrl}/v1.0/users" + + private val domainMap: Map[String, String] = config.domains.getOrElse(Map.empty) + + // Cache the Graph access token; expires well before the typical 1-hour token lifetime + private val accessTokenCache: LoadingCache[String, String] = + CacheBuilder.newBuilder() + .expireAfterWrite(50, TimeUnit.MINUTES) + .build(new CacheLoader[String, String] { + override def load(key: String): String = fetchAccessToken() + }) + + override def resolveUsername(upn: String): Option[String] = { + Try { + val accessToken = accessTokenCache.get("token") + val (samOpt, domainOpt) = queryGraphForUser(accessToken, upn) + + (samOpt.filter(_.nonEmpty), domainOpt.filter(_.nonEmpty)) match { + case (Some(sam), Some(dnsDomain)) => + resolveMappedUsername(sam, dnsDomain, upn) + + case _ => + logger.debug(s"User '$upn' has no on-premises AD attributes; using UPN as username") + None + } + } match { + case Success(result) => result + case Failure(e) => + logger.warn(s"Graph API lookup failed for '$upn': ${e.getMessage}") + None + } + } + + private[entra] def resolveMappedUsername(sam: String, dnsDomain: String, upn: String): Option[String] = { + domainMap.get(dnsDomain) match { + case Some(netbios) => + val normalizedSam = sam.toLowerCase(Locale.ROOT) + logger.debug( + s"Resolved user '$upn' to '$normalizedSam' via Graph API " + + s"(mapped AB value '$netbios' from domain '$dnsDomain')" + ) + Some(normalizedSam) + case None => + logger.error( + s"Unknown onPremisesDomainName '$dnsDomain' for user '$upn'. " + + "Add it to the 'domains' mapping in the Entra config. Falling back to UPN." + ) + None + } + } + + private def fetchAccessToken(): String = { + val secret = config.clientSecret.getOrElse( + throw new IllegalStateException("clientSecret is required to call the Graph API") + ) + val body = Seq( + "grant_type" -> "client_credentials", + "client_id" -> config.clientId, + "client_secret" -> secret, + "scope" -> s"${config.graphBaseUrl}/.default" + ).map { case (k, v) => URLEncoder.encode(k, "UTF-8") + "=" + URLEncoder.encode(v, "UTF-8") } + .mkString("&") + + val responseJson = httpPost(tokenEndpoint, body) + parseJsonStringField(responseJson, "access_token") + .getOrElse(throw new IllegalStateException("No access_token in token endpoint response")) + } + + private def queryGraphForUser(accessToken: String, upn: String): (Option[String], Option[String]) = { + val encodedUpn = URLEncoder.encode(upn, "UTF-8") + val url = s"$graphUsersBaseUrl/$encodedUpn?$$select=onPremisesSamAccountName,onPremisesDomainName" + val responseJson = httpGet(url, accessToken) + val sam = parseJsonStringField(responseJson, "onPremisesSamAccountName") + val domain = parseJsonStringField(responseJson, "onPremisesDomainName") + (sam, domain) + } + + private def httpPost(urlStr: String, body: String): String = { + val conn = new URL(urlStr).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setDoOutput(true) + conn.setConnectTimeout(5000) + conn.setReadTimeout(5000) + conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded") + val bytes = body.getBytes("UTF-8") + conn.setRequestProperty("Content-Length", bytes.length.toString) + val out = new DataOutputStream(conn.getOutputStream) + try { out.write(bytes) } finally { out.close() } + readResponse(conn) + } + + private def httpGet(urlStr: String, bearerToken: String): String = { + val conn = new URL(urlStr).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + conn.setConnectTimeout(5000) + conn.setReadTimeout(5000) + conn.setRequestProperty("Authorization", s"Bearer $bearerToken") + conn.setRequestProperty("ConsistencyLevel", "eventual") + readResponse(conn) + } + + private def readResponse(conn: HttpURLConnection): String = { + val status = conn.getResponseCode + val stream: InputStream = + if (status >= 200 && status < 300) conn.getInputStream else conn.getErrorStream + val body = Source.fromInputStream(stream, "UTF-8").mkString + if (status >= 200 && status < 300) body + else throw new RuntimeException(s"HTTP $status from ${conn.getURL.getHost}: $body") + } + + /** Extracts a string-valued field from a flat JSON object, returning None if absent or null. */ + private def parseJsonStringField(json: String, fieldName: String): Option[String] = { + val escapedName = java.util.regex.Pattern.quote(fieldName) + val pattern = ("\"" + escapedName + "\"\\s*:\\s*\"([^\"]+)\"").r + pattern.findFirstMatchIn(json).map(_.group(1)) + } +} diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidator.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidator.scala new file mode 100644 index 0000000..c983135 --- /dev/null +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidator.scala @@ -0,0 +1,164 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.nimbusds.jose.JWSAlgorithm +import com.nimbusds.jose.jwk.source.{JWKSource, RemoteJWKSet} +import com.nimbusds.jose.proc.{JWSVerificationKeySelector, SecurityContext => NimbusSecurityContext} +import com.nimbusds.jwt.proc.{BadJWTException, DefaultJWTClaimsVerifier, DefaultJWTProcessor} +import com.nimbusds.jwt.{JWTClaimsSet, SignedJWT} +import org.slf4j.LoggerFactory +import za.co.absa.loginsvc.model.User +import za.co.absa.loginsvc.rest.config.auth.MsEntraConfig + +import java.net.URL +import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + +/** + * Validates MS Entra (Azure AD) Bearer JWT tokens. + * + * Fetches the JWKS URI from Microsoft's OIDC discovery endpoint and caches it. + * Validates the token's signature, expiry, issuer and audience, then extracts a [[User]]. + * + * The discovery URL follows the Microsoft standard: + * https://login.microsoftonline.com/{tenantId}/v2.0/.well-known/openid-configuration + * + * @param config Entra configuration + * @param jwkSourceOverride Optional override for the JWK source (used in tests to avoid HTTP calls) + * @param graphClientOverride Optional override for the Graph username resolver (used in tests) + */ +class MsEntraTokenValidator( + config: MsEntraConfig, + private[entra] val jwkSourceOverride: Option[JWKSource[NimbusSecurityContext]] = None, + private[entra] val graphClientOverride: Option[GraphUsernameResolver] = None +) { + + private val logger = LoggerFactory.getLogger(classOf[MsEntraTokenValidator]) + + // Use the injected graph resolver (for tests) or create one from config if clientSecret is set + private val graphResolver: Option[GraphUsernameResolver] = + graphClientOverride.orElse(config.clientSecret.map(_ => new MsEntraGraphClient(config))) + + private val discoveryUrl = + s"${config.loginBaseUrl}/${config.tenantId}/v2.0/.well-known/openid-configuration" + + private val expectedIssuer = + s"${config.loginBaseUrl}/${config.tenantId}/v2.0" + + // Cache the JWKSource keyed by jwks_uri string; refreshes after 1 hour + private val jwkSourceCache: LoadingCache[String, JWKSource[NimbusSecurityContext]] = + CacheBuilder.newBuilder() + .expireAfterWrite(1, TimeUnit.HOURS) + .build(new CacheLoader[String, JWKSource[NimbusSecurityContext]] { + override def load(jwksUri: String): JWKSource[NimbusSecurityContext] = { + logger.info(s"Loading JWKS from $jwksUri") + new RemoteJWKSet[NimbusSecurityContext](new URL(jwksUri)) + } + }) + + /** + * Validates the given raw Entra JWT string. + * + * @param rawToken the Bearer token string (without "Bearer " prefix) + * @return a [[User]] if the token is valid, or a Failure with a descriptive exception + */ + def validate(rawToken: String): Try[User] = { + Try { + val jwkSource = jwkSourceOverride.getOrElse { + val jwksUri = resolveJwksUri() + jwkSourceCache.get(jwksUri) + } + + val jwtProcessor = new DefaultJWTProcessor[NimbusSecurityContext]() + val keySelector = new JWSVerificationKeySelector[NimbusSecurityContext]( + JWSAlgorithm.RS256, + jwkSource + ) + jwtProcessor.setJWSKeySelector(keySelector) + + // Verify signature, issuer, expiry and required claims + val requiredClaims = new DefaultJWTClaimsVerifier[NimbusSecurityContext]( + new JWTClaimsSet.Builder().issuer(expectedIssuer).build(), + Set("sub", "iat", "exp").asJava + ) + jwtProcessor.setJWTClaimsSetVerifier(requiredClaims) + + val claims: JWTClaimsSet = jwtProcessor.process(rawToken, null) + + // Audience check: if audiences are configured, token must contain at least one + if (config.audiences.nonEmpty) { + val tokenAudiences = Option(claims.getAudience).map(_.asScala.toSet).getOrElse(Set.empty) + val configAudiences = config.audiences.toSet + if (tokenAudiences.intersect(configAudiences).isEmpty) + throw new BadJWTException( + s"JWT aud claim has value $tokenAudiences, must include one of $configAudiences" + ) + } + + extractUser(claims) + } recoverWith { + case e: BadJWTException => + logger.warn(s"Entra token validation failed (claims): ${e.getMessage}") + Failure(e) + case e: Exception => + logger.warn(s"Entra token validation failed: ${e.getMessage}") + Failure(e) + } + } + + private def extractUser(claims: JWTClaimsSet): User = { + val rawUsername = Option(claims.getStringClaim("preferred_username")) + .orElse(Option(claims.getStringClaim("upn"))) + .orElse(Option(claims.getSubject)) + .getOrElse(throw new IllegalArgumentException("Entra token has no usable username claim (preferred_username/upn/sub)")) + + // Attempt to resolve to lower-case on-premises samAccountName via Graph API; fall back to UPN + val username = graphResolver.flatMap(_.resolveUsername(rawUsername)).getOrElse(rawUsername) + + val groups: Seq[String] = Option(claims.getStringListClaim("groups")) + .map(_.asScala.toSeq) + .getOrElse(Seq.empty) + + val optionalAttributes: Map[String, Option[AnyRef]] = config.attributes.getOrElse(Map.empty).flatMap { + case (claimName, lsClaimName) => + Option(claims.getClaim(claimName)).map { value => + lsClaimName -> Some(value.asInstanceOf[AnyRef]) + } + } + + User(username, groups, optionalAttributes) + } + + private def resolveJwksUri(): String = { + val conn = new URL(discoveryUrl).openConnection() + conn.setConnectTimeout(5000) + conn.setReadTimeout(5000) + val json = scala.io.Source.fromInputStream(conn.getInputStream).mkString + // Simple string extraction without pulling in additional JSON libraries + val jwksUriPattern = """"jwks_uri"\s*:\s*"([^"]+)"""".r + jwksUriPattern.findFirstMatchIn(json) + .map(_.group(1)) + .getOrElse(throw new IllegalStateException(s"Could not find jwks_uri in OIDC discovery doc at $discoveryUrl")) + } +} + +object MsEntraTokenValidator { + def apply(config: MsEntraConfig): MsEntraTokenValidator = new MsEntraTokenValidator(config) +} diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfigTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfigTest.scala new file mode 100644 index 0000000..86f1a84 --- /dev/null +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/auth/MsEntraConfigTest.scala @@ -0,0 +1,78 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.config.auth + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException +import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} + +class MsEntraConfigTest extends AnyFlatSpec with Matchers { + + private val validConfig = MsEntraConfig( + tenantId = "test-tenant-id", + clientId = "test-client-id", + audiences = List("api://test-client-id", "other-app-client-id"), + order = 1, + attributes = Some(Map("preferred_username" -> "upn", "email" -> "email")) + ) + + "MsEntraConfig" should "validate expected filled content" in { + validConfig.validate() shouldBe ConfigValidationSuccess + } + + it should "validate with no attributes (they are optional)" in { + validConfig.copy(attributes = None).validate() shouldBe ConfigValidationSuccess + } + + it should "validate with empty attributes" in { + validConfig.copy(attributes = Some(Map.empty)).validate() shouldBe ConfigValidationSuccess + } + + it should "fail on null tenantId" in { + val result = validConfig.copy(tenantId = null).validate() + result shouldBe ConfigValidationError(ConfigValidationException("tenantId is empty")) + } + + it should "fail on null clientId" in { + val result = validConfig.copy(clientId = null).validate() + result shouldBe ConfigValidationError(ConfigValidationException("clientId is empty")) + } + + it should "pass validation with empty audiences (accept any token from the tenant)" in { + validConfig.copy(audiences = List.empty).validate() shouldBe ConfigValidationSuccess + } + + it should "accumulate multiple validation errors" in { + val result = validConfig.copy(tenantId = null, clientId = null).validate() + result shouldBe a[ConfigValidationError] + result.errors should have size 2 + result.errors.map(_.msg) should contain allOf ("tenantId is empty", "clientId is empty") + } + + it should "pass validation when disabled (order=0) even with empty fields" in { + MsEntraConfig(tenantId = null, clientId = null, audiences = List.empty, order = 0, attributes = None) + .validate() shouldBe ConfigValidationSuccess + } + + it should "throw on throwErrors() when invalid" in { + val exception = intercept[ConfigValidationException] { + validConfig.copy(tenantId = null).throwErrors() + } + exception.msg should include("tenantId is empty") + } +} diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilterTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilterTest.scala new file mode 100644 index 0000000..10a3d19 --- /dev/null +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraBearerTokenFilterTest.scala @@ -0,0 +1,140 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import org.mockito.ArgumentMatchers.anyString +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.springframework.mock.web.{MockFilterChain, MockHttpServletRequest, MockHttpServletResponse} +import org.springframework.security.core.context.SecurityContextHolder +import za.co.absa.loginsvc.model.User + +import javax.servlet.http.HttpServletResponse +import scala.util.{Failure, Success} + +class MsEntraBearerTokenFilterTest extends AnyFlatSpec with Matchers with BeforeAndAfterEach { + + private val fakeUser = User("user@example.com", Seq("group1", "group2"), Map.empty) + + private val mockValidator = mock(classOf[MsEntraTokenValidator]) + private val filter = new MsEntraBearerTokenFilter(mockValidator) + + override def beforeEach(): Unit = { + SecurityContextHolder.clearContext() + } + + override def afterEach(): Unit = { + SecurityContextHolder.clearContext() + } + + "MsEntraBearerTokenFilter" should "authenticate and pass through on a valid Bearer token" in { + when(mockValidator.validate(anyString())).thenReturn(Success(fakeUser)) + + val request = new MockHttpServletRequest() + request.addHeader("Authorization", "Bearer valid.entra.token") + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + response.getStatus shouldBe HttpServletResponse.SC_OK + val auth = SecurityContextHolder.getContext.getAuthentication + auth should not be null + auth.getPrincipal shouldBe fakeUser + chain.getRequest should not be null // filter chain was called + } + + it should "return 401 and not call filter chain on an invalid Bearer token" in { + when(mockValidator.validate(anyString())).thenReturn(Failure(new Exception("Bad token"))) + + val request = new MockHttpServletRequest() + request.addHeader("Authorization", "Bearer invalid.token") + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + response.getStatus shouldBe HttpServletResponse.SC_UNAUTHORIZED + response.getContentType shouldBe "application/json" + response.getContentAsString should include("Invalid or expired Entra token") + chain.getRequest shouldBe null // filter chain was NOT called + } + + it should "pass through without calling the validator when no Authorization header is present" in { + val request = new MockHttpServletRequest() + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + response.getStatus shouldBe HttpServletResponse.SC_OK + SecurityContextHolder.getContext.getAuthentication shouldBe null + chain.getRequest should not be null + } + + it should "pass through without calling the validator when Authorization header is not a Bearer token" in { + val request = new MockHttpServletRequest() + request.addHeader("Authorization", "Basic dXNlcjpwYXNzd29yZA==") + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + response.getStatus shouldBe HttpServletResponse.SC_OK + SecurityContextHolder.getContext.getAuthentication shouldBe null + chain.getRequest should not be null + } + + it should "skip validation and pass through when SecurityContext is already authenticated" in { + when(mockValidator.validate(anyString())).thenReturn(Success(fakeUser)) + + // Pre-populate the security context as if another filter already authenticated + val preExistingAuth = new org.springframework.security.authentication.UsernamePasswordAuthenticationToken( + "already-authenticated-user", "creds", + new java.util.ArrayList[org.springframework.security.core.GrantedAuthority]() + ) + SecurityContextHolder.getContext.setAuthentication(preExistingAuth) + + val request = new MockHttpServletRequest() + request.addHeader("Authorization", "Bearer some.token") + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + // Authentication should remain the pre-existing one + SecurityContextHolder.getContext.getAuthentication.getPrincipal shouldBe "already-authenticated-user" + chain.getRequest should not be null + } + + it should "populate groups as Spring authorities" in { + when(mockValidator.validate(anyString())).thenReturn(Success(fakeUser)) + + val request = new MockHttpServletRequest() + request.addHeader("Authorization", "Bearer valid.entra.token") + val response = new MockHttpServletResponse() + val chain = new MockFilterChain() + + filter.doFilter(request, response, chain) + + import scala.collection.JavaConverters._ + val authorities = SecurityContextHolder.getContext.getAuthentication.getAuthorities.asScala.map(_.getAuthority) + authorities should contain allOf ("group1", "group2") + } +} diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClientTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClientTest.scala new file mode 100644 index 0000000..388998a --- /dev/null +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraGraphClientTest.scala @@ -0,0 +1,151 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.loginsvc.rest.config.auth.MsEntraConfig + +import java.net.InetSocketAddress +import java.nio.charset.StandardCharsets + +class MsEntraGraphClientTest extends AnyFlatSpec with Matchers with BeforeAndAfterAll with BeforeAndAfterEach { + + @volatile private var tokenStatus = 200 + @volatile private var tokenResponseBody = """{"access_token":"graph-token"}""" + @volatile private var graphStatus = 200 + @volatile private var graphResponseBody = """{"onPremisesSamAccountName":"UN123XY","onPremisesDomainName":"corp.dsarena.com"}""" + @volatile private var lastTokenRequestBody = "" + @volatile private var lastGraphAuthorization = "" + @volatile private var lastGraphPath = "" + @volatile private var lastGraphRawPath = "" + @volatile private var lastGraphQuery = "" + + private var server: HttpServer = _ + private var baseUrl: String = _ + + override protected def beforeAll(): Unit = { + super.beforeAll() + server = HttpServer.create(new InetSocketAddress(0), 0) + server.createContext("/tenant-id/oauth2/v2.0/token", new HttpHandler { + override def handle(exchange: HttpExchange): Unit = { + lastTokenRequestBody = new String(exchange.getRequestBody.readAllBytes(), StandardCharsets.UTF_8) + respond(exchange, tokenStatus, tokenResponseBody) + } + }) + server.createContext("/v1.0/users", new HttpHandler { + override def handle(exchange: HttpExchange): Unit = { + lastGraphAuthorization = Option(exchange.getRequestHeaders.getFirst("Authorization")).getOrElse("") + lastGraphPath = exchange.getRequestURI.getPath + lastGraphRawPath = exchange.getRequestURI.getRawPath + lastGraphQuery = Option(exchange.getRequestURI.getRawQuery).getOrElse("") + respond(exchange, graphStatus, graphResponseBody) + } + }) + server.start() + baseUrl = s"http://127.0.0.1:${server.getAddress.getPort}" + } + + override protected def afterAll(): Unit = { + if (server != null) server.stop(0) + super.afterAll() + } + + override protected def beforeEach(): Unit = { + tokenStatus = 200 + tokenResponseBody = """{"access_token":"graph-token"}""" + graphStatus = 200 + graphResponseBody = """{"onPremisesSamAccountName":"UN123XY","onPremisesDomainName":"corp.dsarena.com"}""" + lastTokenRequestBody = "" + lastGraphAuthorization = "" + lastGraphPath = "" + lastGraphRawPath = "" + lastGraphQuery = "" + super.beforeEach() + } + + private def respond(exchange: HttpExchange, status: Int, body: String): Unit = { + val bytes = body.getBytes(StandardCharsets.UTF_8) + exchange.sendResponseHeaders(status, bytes.length.toLong) + val os = exchange.getResponseBody + try os.write(bytes) finally os.close() + } + + private def client( + secret: Option[String] = Some("test-secret"), + domains: Option[Map[String, String]] = Some(Map("corp.dsarena.com" -> "CORP")) + ): MsEntraGraphClient = + new MsEntraGraphClient( + MsEntraConfig( + tenantId = "tenant-id", + clientId = "client-id", + clientSecret = secret, + audiences = Nil, + domains = domains, + order = 1, + attributes = None, + loginBaseUrl = baseUrl, + graphBaseUrl = baseUrl + ) + ) + + "MsEntraGraphClient" should "return a lowercase samAccountName without the domain prefix" in { + client().resolveUsername("john.smith@example.com") shouldBe Some("un123xy") + lastTokenRequestBody should include("grant_type=client_credentials") + lastTokenRequestBody should include("client_id=client-id") + lastTokenRequestBody should include("client_secret=test-secret") + lastGraphAuthorization shouldBe "Bearer graph-token" + lastGraphPath shouldBe "/v1.0/users/john.smith@example.com" + lastGraphRawPath shouldBe "/v1.0/users/john.smith%40example.com" + lastGraphQuery shouldBe "$select=onPremisesSamAccountName,onPremisesDomainName" + } + + it should "return None when the user's domain is not in the configured allow-list" in { + graphResponseBody = """{"onPremisesSamAccountName":"UN123XY","onPremisesDomainName":"unknown.domain"}""" + + client().resolveUsername("john.smith@example.com") shouldBe None + } + + it should "return None when Graph does not provide on-premises attributes" in { + graphResponseBody = """{"displayName":"John Smith"}""" + + client().resolveUsername("john.smith@example.com") shouldBe None + } + + it should "return None when the token endpoint response has no access token" in { + tokenResponseBody = """{"token_type":"Bearer"}""" + + client().resolveUsername("john.smith@example.com") shouldBe None + } + + it should "return None when the Graph API responds with an error" in { + graphStatus = 500 + graphResponseBody = """{"error":"server exploded"}""" + + client().resolveUsername("john.smith@example.com") shouldBe None + } + + it should "return None when the client secret is missing" in { + client(secret = None).resolveUsername("john.smith@example.com") shouldBe None + } + + it should "allow direct testing of the mapped username helper" in { + client().resolveMappedUsername("UN123XY", "corp.dsarena.com", "john.smith@example.com") shouldBe Some("un123xy") + } +} diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidatorTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidatorTest.scala new file mode 100644 index 0000000..1056821 --- /dev/null +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/provider/entra/MsEntraTokenValidatorTest.scala @@ -0,0 +1,217 @@ +/* + * Copyright 2023 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.loginsvc.rest.provider.entra + +import com.nimbusds.jose.crypto.RSASSASigner +import com.nimbusds.jose.jwk.source.ImmutableJWKSet +import com.nimbusds.jose.jwk.{JWKSet, RSAKey} +import com.nimbusds.jose.proc.{SecurityContext => NimbusSecurityContext} +import com.nimbusds.jose.{JWSAlgorithm, JWSHeader} +import com.nimbusds.jwt.{JWTClaimsSet, SignedJWT} +import org.mockito.ArgumentMatchers.anyString +import org.mockito.Mockito.{mock, when} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import za.co.absa.loginsvc.rest.config.auth.MsEntraConfig + +import java.security.KeyPairGenerator +import java.util.{Date, UUID} +import scala.collection.JavaConverters._ +import scala.util.Success + +class MsEntraTokenValidatorTest extends AnyFlatSpec with Matchers { + + private val tenantId = "test-tenant-id" + private val clientId = "test-client-id" + private val audience = "api://test-client-id" + private val audience2 = "other-app-client-id" + private val issuer = s"https://login.microsoftonline.com/$tenantId/v2.0" + + private val config = MsEntraConfig( + tenantId = tenantId, + clientId = clientId, + audiences = List(audience, audience2), + order = 1, + attributes = Some(Map("email" -> "email")) + ) + + // Generate a real RSA key pair for testing + private val keyPairGenerator = KeyPairGenerator.getInstance("RSA") + keyPairGenerator.initialize(2048) + private val keyPair = keyPairGenerator.generateKeyPair() + private val rsaJwk = new RSAKey.Builder(keyPair.getPublic.asInstanceOf[java.security.interfaces.RSAPublicKey]) + .privateKey(keyPair.getPrivate) + .keyID(UUID.randomUUID().toString) + .build() + + private val jwkSet = new JWKSet(rsaJwk) + private val jwkSource = new ImmutableJWKSet[NimbusSecurityContext](jwkSet) + + private val validator = new MsEntraTokenValidator(config, Some(jwkSource)) + + private def buildToken( + subject: String = "user-oid-123", + preferredUsername: String = "user@example.com", + groups: Seq[String] = Seq("group1", "group2"), + email: String = "user@example.com", + expiresInSeconds: Int = 3600, + issuerOverride: String = issuer, + audienceOverride: String = audience + ): String = { + val now = new Date() + val exp = new Date(now.getTime + expiresInSeconds * 1000L) + + val claims = new JWTClaimsSet.Builder() + .subject(subject) + .issuer(issuerOverride) + .audience(audienceOverride) + .issueTime(now) + .expirationTime(exp) + .claim("preferred_username", preferredUsername) + .claim("groups", groups.asJava) + .claim("email", email) + .build() + + val header = new JWSHeader.Builder(JWSAlgorithm.RS256) + .keyID(rsaJwk.getKeyID) + .build() + + val jwt = new SignedJWT(header, claims) + jwt.sign(new RSASSASigner(rsaJwk)) + jwt.serialize() + } + + "MsEntraTokenValidator" should "return a User for a valid token" in { + val token = buildToken() + val result = validator.validate(token) + + result shouldBe a[Success[_]] + val user = result.get + user.name shouldBe "user@example.com" + user.groups should contain theSameElementsAs Seq("group1", "group2") + } + + it should "map configured attribute claims to optional attributes" in { + val token = buildToken(email = "mapped@example.com") + val user = validator.validate(token).get + user.optionalAttributes.get("email") shouldBe Some(Some("mapped@example.com")) + } + + it should "use 'upn' claim as username when preferred_username is absent" in { + val now = new Date() + val exp = new Date(now.getTime + 3600 * 1000L) + val claims = new JWTClaimsSet.Builder() + .subject("sub-id") + .issuer(issuer) + .audience(audience) + .issueTime(now) + .expirationTime(exp) + .claim("upn", "upnuser@example.com") + .claim("groups", Seq.empty[String].asJava) + .build() + val jwt = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJwk.getKeyID).build(), claims) + jwt.sign(new RSASSASigner(rsaJwk)) + + val user = validator.validate(jwt.serialize()).get + user.name shouldBe "upnuser@example.com" + } + + it should "fall back to sub claim as username when neither preferred_username nor upn is present" in { + val now = new Date() + val exp = new Date(now.getTime + 3600 * 1000L) + val claims = new JWTClaimsSet.Builder() + .subject("sub-only-user") + .issuer(issuer) + .audience(audience) + .issueTime(now) + .expirationTime(exp) + .claim("groups", Seq.empty[String].asJava) + .build() + val jwt = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJwk.getKeyID).build(), claims) + jwt.sign(new RSASSASigner(rsaJwk)) + + val user = validator.validate(jwt.serialize()).get + user.name shouldBe "sub-only-user" + } + + it should "return a Failure for an expired token" in { + // Use -120s to exceed Nimbus's default 60s clock-skew tolerance + val token = buildToken(expiresInSeconds = -120) + val result = validator.validate(token) + result.isFailure shouldBe true + } + + it should "return a Failure for a token with wrong issuer" in { + val token = buildToken(issuerOverride = "https://evil.example.com") + val result = validator.validate(token) + result.isFailure shouldBe true + } + + it should "accept a token with the second configured audience" in { + val token = buildToken(audienceOverride = audience2) + validator.validate(token) shouldBe a[Success[_]] + } + + it should "return a Failure for a token with wrong audience" in { + val token = buildToken(audienceOverride = "api://different-client") + val result = validator.validate(token) + result.isFailure shouldBe true + } + + it should "return a Failure for a malformed token string" in { + val result = validator.validate("not.a.valid.jwt") + result.isFailure shouldBe true + } + + it should "return empty groups when groups claim is absent" in { + val now = new Date() + val exp = new Date(now.getTime + 3600 * 1000L) + val claims = new JWTClaimsSet.Builder() + .subject("sub-id") + .issuer(issuer) + .audience(audience) + .issueTime(now) + .expirationTime(exp) + .claim("preferred_username", "nogroups@example.com") + .build() + val jwt = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJwk.getKeyID).build(), claims) + jwt.sign(new RSASSASigner(rsaJwk)) + + val user = validator.validate(jwt.serialize()).get + user.groups shouldBe empty + } + + it should "use the normalized samAccountName from Graph when graphClientOverride resolves the username" in { + val mockGraph = mock(classOf[GraphUsernameResolver]) + when(mockGraph.resolveUsername("user@example.com")).thenReturn(Some("jsmith")) + + val validatorWithGraph = new MsEntraTokenValidator(config, Some(jwkSource), Some(mockGraph)) + val token = buildToken() + val user = validatorWithGraph.validate(token).get + user.name shouldBe "jsmith" + } + + it should "fall back to UPN when the graph resolver returns None" in { + val mockGraph = mock(classOf[GraphUsernameResolver]) + when(mockGraph.resolveUsername(anyString())).thenReturn(None) + + val validatorWithGraph = new MsEntraTokenValidator(config, Some(jwkSource), Some(mockGraph)) + val token = buildToken(preferredUsername = "fallback@example.com") + val user = validatorWithGraph.validate(token).get + user.name shouldBe "fallback@example.com" + } +} diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/service/actuator/LdapHealthServiceTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/service/actuator/LdapHealthServiceTest.scala index 8a7cfcc..fc2c475 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/service/actuator/LdapHealthServiceTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/service/actuator/LdapHealthServiceTest.scala @@ -20,7 +20,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import org.springframework.boot.actuate.health.Health import org.springframework.boot.test.context.SpringBootTest -import za.co.absa.loginsvc.rest.config.auth.{ActiveDirectoryLDAPConfig, LdapUserCredentialsConfig, ServiceAccountConfig, UsersConfig} +import za.co.absa.loginsvc.rest.config.auth.{ActiveDirectoryLDAPConfig, LdapUserCredentialsConfig, MsEntraConfig, ServiceAccountConfig, UsersConfig} import za.co.absa.loginsvc.rest.config.provider.AuthConfigProvider import javax.naming.CommunicationException @@ -55,8 +55,8 @@ class LdapHealthServiceTest extends AnyFlatSpec with Matchers { "LdapHealthService" should "Return Up on Order 0" in { val configProvider = new AuthConfigProvider { override def getLdapConfig: Option[ActiveDirectoryLDAPConfig] = Some(ldapCfgZeroOrder) - override def getUsersConfig: Option[UsersConfig] = None + override def getMsEntraConfig: Option[MsEntraConfig] = None } val ldapHealthService: LdapHealthService = new testLdapHealthService(configProvider) val health = ldapHealthService.health() @@ -67,8 +67,8 @@ class LdapHealthServiceTest extends AnyFlatSpec with Matchers { "LdapHealthService" should "Return Up when ActiveDirectoryLDAPConfig is None" in { val configProvider = new AuthConfigProvider { override def getLdapConfig: Option[ActiveDirectoryLDAPConfig] = None - override def getUsersConfig: Option[UsersConfig] = None + override def getMsEntraConfig: Option[MsEntraConfig] = None } val ldapHealthService: LdapHealthService = new testLdapHealthService(configProvider) val health = ldapHealthService.health() @@ -79,8 +79,8 @@ class LdapHealthServiceTest extends AnyFlatSpec with Matchers { "LdapHealthService" should "Return Down when Ldap connection fails" in { val configProvider = new AuthConfigProvider { override def getLdapConfig: Option[ActiveDirectoryLDAPConfig] = Some(ldapCfgZeroOrder.copy(order = 2)) - override def getUsersConfig: Option[UsersConfig] = None + override def getMsEntraConfig: Option[MsEntraConfig] = None } val ldapHealthService: LdapHealthService = new testLdapHealthService(configProvider) val health = ldapHealthService.health() diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/service/search/DefaultUserRepositoriesTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/service/search/DefaultUserRepositoriesTest.scala index 5ca7c0a..eb0cb90 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/service/search/DefaultUserRepositoriesTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/service/search/DefaultUserRepositoriesTest.scala @@ -20,7 +20,7 @@ import org.scalamock.scalatest.MockFactory import org.scalatest.BeforeAndAfterEach import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import za.co.absa.loginsvc.rest.config.auth.{ActiveDirectoryLDAPConfig, LdapUserCredentialsConfig, ServiceAccountConfig, UserConfig, UsersConfig} +import za.co.absa.loginsvc.rest.config.auth.{ActiveDirectoryLDAPConfig, LdapUserCredentialsConfig, MsEntraConfig, ServiceAccountConfig, UserConfig, UsersConfig} import za.co.absa.loginsvc.rest.config.provider.{AuthConfigProvider, ConfigProvider} import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException @@ -39,6 +39,7 @@ class DefaultUserRepositoriesTest extends AnyFlatSpec with BeforeAndAfterEach wi new AuthConfigProvider { override def getLdapConfig: Option[ActiveDirectoryLDAPConfig] = optLdapConfig override def getUsersConfig: Option[UsersConfig] = optUsersConfig + override def getMsEntraConfig: Option[MsEntraConfig] = None } } diff --git a/build.sbt b/build.sbt index e7ab859..36cc3d6 100644 --- a/build.sbt +++ b/build.sbt @@ -30,6 +30,8 @@ lazy val commonJacocoExcludes: Seq[String] = Seq( lazy val commonJavacOptions = Seq("-source", "1.8", "-target", "1.8", "-Xlint") // deliberately making backwards compatible with J8 +addCommandAlias("runLocal", "api/run --spring.config.location=api/src/main/resources/local.application.yaml") + lazy val parent = (project in file(".")) .aggregate(api, clientLibrary, examples) .enablePlugins(FilteredJacocoAgentPlugin) @@ -50,7 +52,10 @@ lazy val api = project // no need to define file, because path is same as val na webappWebInfClasses := true, inheritJarManifest := true, javacOptions ++= commonJavacOptions, - publish / skip := true + publish / skip := true, + run / fork := true, // required: avoids URLStreamHandlerFactory conflict with xsbt-web-plugin + run / baseDirectory := (ThisBuild / baseDirectory).value, // forked process runs from repo root + run / javaOptions += "-Djavax.net.ssl.trustStoreType=KeychainStore" // use macOS Keychain CA certs (corporate proxy) ).enablePlugins(TomcatPlugin) .enablePlugins(AutomateHeaderPlugin) .enablePlugins(FilteredJacocoAgentPlugin) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 9a6f0a4..c9cef47 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -74,6 +74,9 @@ object Dependencies { // Enables /actuator/health endpoint lazy val springBootStarterActuator = "org.springframework.boot" % "spring-boot-starter-actuator" % Versions.springBoot + // guava Cache for JWKS caching in Entra token validation + lazy val cacheBuilderApi = "com.google.guava" % "guava" % "33.0.0-jre" + lazy val scalaTest = "org.scalatest" %% "scalatest" % Versions.scalatest % Test lazy val springBootTest = "org.springframework.boot" % "spring-boot-starter-test" % Versions.springBoot % Test lazy val springBootSecurityTest = "org.springframework.security" % "spring-security-test" % Versions.spring % Test @@ -111,6 +114,8 @@ object Dependencies { springBootStarterActuator, + cacheBuilderApi, + scalaTest, springBootTest, springBootSecurityTest,