Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.modelix.kotlin.utils

import kotlin.annotation.AnnotationTarget.FUNCTION
import kotlin.annotation.AnnotationTarget.PROPERTY_GETTER
import kotlin.annotation.AnnotationTarget.PROPERTY_SETTER

@OptIn(ExperimentalMultiplatform::class)
@Target(FUNCTION, PROPERTY_GETTER, PROPERTY_SETTER)
@MustBeDocumented
@OptionalExpectation
expect annotation class JvmSynchronized()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package org.modelix.kotlin.utils

actual typealias JvmSynchronized = kotlin.jvm.Synchronized
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package org.modelix.model.client2

internal expect interface Closable {
expect interface Closable {
fun close()
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ expect class ModelixAuthClient() {
)
}

internal fun installAuthWithAuthTokenProvider(config: HttpClientConfig<*>, authTokenProvider: suspend () -> String?) {
internal fun installAuthWithAuthTokenProvider(config: HttpClientConfig<*>, authConfig: TokenProviderAuthConfig) {
config.apply {
install(Auth) {
bearer {
loadTokens {
authTokenProvider()?.let { authToken -> BearerTokens(authToken, "") }
authConfig.provider.getToken(authConfig.tokenParameters)
?.let { authToken -> BearerTokens(authToken, "") }
}
refreshTokens {
val providedToken = authTokenProvider()
val providedToken = authConfig.provider.getToken(authConfig.tokenParameters)
if (providedToken != null && providedToken != this.oldTokens?.accessToken) {
BearerTokens(providedToken, "")
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,113 @@
package org.modelix.model.oauth

import org.modelix.model.lazy.BranchReference
import org.modelix.model.lazy.RepositoryId

sealed interface IAuthConfig {
companion object {
@Deprecated("Provide an ITokenProvider")
fun fromTokenProvider(provider: TokenProvider): IAuthConfig {
return fromTokenProvider(TokenProviderAdapter(provider))
}

fun fromTokenProvider(provider: ITokenProvider): IAuthConfig {
return TokenProviderAuthConfig(provider)
}

fun oauth(body: OAuthConfigBuilder.() -> Unit): IAuthConfig {
return OAuthConfigBuilder(null).apply(body).build()
}
}

fun withTokenParameters(parameters: ITokenParameters): IAuthConfig
}

class TokenProviderAuthConfig(val provider: TokenProvider) : IAuthConfig
class TokenProviderAuthConfig(
val provider: ITokenProvider,
val tokenParameters: ITokenParameters = GlobalTokenParameters(),
) : IAuthConfig {
override fun withTokenParameters(parameters: ITokenParameters): IAuthConfig {
return TokenProviderAuthConfig(provider, parameters)
}
}

data class OAuthConfig(
val clientId: String? = "external-mps",
val clientSecret: String? = null,
val authorizationUrl: String? = null,
val tokenUrl: String? = null,
val repositoryId: RepositoryId? = null,
val tokenParameters: ITokenParameters? = null,
val scopes: Set<String> = emptySet(),
val authRequestHandler: IAuthRequestHandler? = null,
) : IAuthConfig
) : IAuthConfig {
fun getCacheKey() = TokenCacheKey(
clientId = clientId,
clientSecret = clientSecret,
authorizationUrl = authorizationUrl,
tokenUrl = tokenUrl,
scopes = scopes,
)

override fun withTokenParameters(parameters: ITokenParameters): IAuthConfig {
return copy(tokenParameters = parameters)
}
}

data class TokenCacheKey(
val clientId: String?,
val clientSecret: String?,
val authorizationUrl: String?,
val tokenUrl: String?,
val scopes: Set<String>,
)

@Deprecated("use ITokenProvider")
typealias TokenProvider = suspend () -> String?

interface ITokenProvider {
suspend fun getToken(): String? = null
suspend fun getToken(parameters: ITokenParameters): String? = getToken()
}

class TokenProviderAdapter(val provider: suspend () -> String?) : ITokenProvider {
override suspend fun getToken(): String? = provider()
}

interface ITokenParameters {
fun getRepositoryId(): String?
fun getBranchName(): String?
}

class TokenParameters(private val repositoryId: RepositoryId?, private val branchName: String?) : ITokenParameters {
constructor(repositoryId: RepositoryId) : this(repositoryId, null)
constructor(branchReference: BranchReference) : this(branchReference.repositoryId, branchReference.branchName)

private var dependsOnRepositoryId: Boolean = false
private var dependsOnBranchName: Boolean = false

override fun getRepositoryId(): String? {
dependsOnRepositoryId = true
return repositoryId?.id
}

override fun getBranchName(): String? {
dependsOnBranchName = true
return branchName
}

fun createCacheKey(): Any {
return listOf(
repositoryId?.id?.takeIf { dependsOnRepositoryId },
branchName?.takeIf { dependsOnBranchName },
)
}
}

class GlobalTokenParameters : ITokenParameters {
override fun getRepositoryId(): String? = null
override fun getBranchName(): String? = null
}

class OAuthConfigBuilder(initial: OAuthConfig?) {
private var config = initial ?: OAuthConfig()

Expand All @@ -39,7 +119,13 @@ class OAuthConfigBuilder(initial: OAuthConfig?) {
fun additionalScope(scope: String) = additionalScopes(setOf(scope))
fun authorizationUrl(url: String) = also { config = config.copy(authorizationUrl = url) }
fun tokenUrl(url: String) = also { config = config.copy(tokenUrl = url) }
fun repositoryId(repositoryId: RepositoryId) = also { config = config.copy(repositoryId = repositoryId) }
fun repositoryId(repositoryId: RepositoryId) = tokenParameters(
TokenParameters(
repositoryId = repositoryId,
branchName = null,
),
)
fun tokenParameters(parameters: ITokenParameters) = also { config = config.copy(tokenParameters = parameters) }
fun oidcUrl(url: String) = authorizationUrl(url.trimEnd('/') + "/auth").tokenUrl(url.trimEnd('/') + "/token")
fun authRequestHandler(handler: IAuthRequestHandler?) = also { config = config.copy(authRequestHandler = handler) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ModelClientV2Test {
val exception = assertFailsWith<CancellationException> {
modelClient.init()
}
assertEquals("Parent job is Completed", exception.message)
assertEquals("Already closed", exception.message)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package org.modelix.model.client2

internal actual interface Closable {
actual interface Closable {
actual fun close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package org.modelix.model.client2

import io.ktor.client.HttpClient
import io.ktor.client.engine.js.Js
import org.modelix.model.oauth.ITokenParameters

actual class ModelClientV2PlatformSpecificBuilder : ModelClientV2Builder() {
actual override fun createHttpClient(): HttpClient {
actual override fun createHttpClient(tokenParameters: ITokenParameters): HttpClient {
return HttpClient(Js) {
configureHttpClient(this)
configureHttpClient(this, tokenParameters)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ actual class ModelixAuthClient {
authConfig: IAuthConfig,
) {
when (authConfig) {
is OAuthConfig -> UnsupportedOperationException("JS client doesn't support OAuth2")
is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig.provider)
is OAuthConfig -> throw UnsupportedOperationException("JS client doesn't support OAuth2")
is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ import org.modelix.model.KeyValueStoreCache
import org.modelix.model.api.IIdGenerator
import org.modelix.model.lazy.IDeserializingKeyValueStore
import org.modelix.model.lazy.createObjectStoreCache
import org.modelix.model.oauth.ModelixAuthClient
import org.modelix.model.persistent.HashUtil
import org.modelix.model.sleep
import org.modelix.model.util.StreamUtils.toStream
Expand Down Expand Up @@ -149,13 +148,12 @@ class RestWebModelClient @JvmOverloads constructor(
},
)
}
val modelixAuthClient = ModelixAuthClient()
install(Auth) {
bearer {
loadTokens {
val tp = authTokenProvider
if (tp == null) {
modelixAuthClient.getTokens()?.let { BearerTokens(it.accessToken, it.refreshToken) }
null
} else {
val token = tp()
if (token == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package org.modelix.model.client2

internal actual interface Closable : java.io.Closeable {
actual interface Closable : java.io.Closeable {

Check warning

Code scanning / detekt

Closable is missing required documentation. Warning

Closable is missing required documentation.
actual override fun close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package org.modelix.model.client2

import io.ktor.client.HttpClient
import io.ktor.client.engine.cio.CIO
import org.modelix.model.oauth.ITokenParameters

actual class ModelClientV2PlatformSpecificBuilder : ModelClientV2Builder() {
actual override fun createHttpClient(): HttpClient {
actual override fun createHttpClient(tokenParameters: ITokenParameters): HttpClient {
return HttpClient(CIO) {
configureHttpClient(this)
configureHttpClient(this, tokenParameters)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.modelix.kotlin.utils.runSynchronized
import org.modelix.kotlin.utils.urlEncode
import java.net.SocketException
import java.net.SocketTimeoutException
Expand All @@ -44,16 +45,19 @@
}
private val httpTransport: HttpTransport = NetHttpTransport()
private val jsonFactory: JsonFactory = GsonFactory()
private var lastCredentials: Credential? = null
private var lastCredentials = HashMap<TokenCacheKey, Credential>()

Check warning

Code scanning / detekt

Variable lastCredentials is declared as var with a mutable type java.util.HashMap. Consider using val or an immutable collection or value type Warning

Variable lastCredentials is declared as var with a mutable type java.util.HashMap. Consider using val or an immutable collection or value type

Check warning

Code scanning / detekt

Variable 'lastCredentials' could be val. Warning

Variable 'lastCredentials' could be val.

fun getTokens(): Credential? {
return lastCredentials?.takeIf { !it.isExpired() }
@Synchronized
fun getTokens(config: OAuthConfig): Credential? {

Check warning

Code scanning / detekt

The function getTokens is missing documentation. Warning

The function getTokens is missing documentation.
return lastCredentials[config.getCacheKey()]?.takeIf { !it.isExpired() }
}

@Synchronized
private fun Credential.isExpired(): Boolean {
return (expiresInSeconds ?: return false) < 60
}

@Synchronized
private fun Credential.refreshIfExpired(): Credential? {
return if (isExpired()) {
alwaysRefresh()
Expand All @@ -62,6 +66,7 @@
}
}

@Synchronized
fun Credential.alwaysRefresh(): Credential? {
for (attempt in 1..3) {
try {
Expand All @@ -77,12 +82,13 @@
return null
}

suspend fun getAndMaybeRefreshTokens(): Credential? {
return lastCredentials?.refreshIfExpired()
@Synchronized
fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? {

Check warning

Code scanning / detekt

The function getAndMaybeRefreshTokens is missing documentation. Warning

The function getAndMaybeRefreshTokens is missing documentation.
return lastCredentials[config.getCacheKey()]?.refreshIfExpired()
}

suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? {
return lastCredentials?.alwaysRefresh() ?: authorize(config)
return runSynchronized(this) { lastCredentials[config.getCacheKey()]?.alwaysRefresh() } ?: authorize(config)
}

suspend fun authorize(config: OAuthConfig): Credential? {
Expand Down Expand Up @@ -110,7 +116,9 @@
val tokens = cancelable({ receiver.stop() }) {
AuthorizationCodeInstalledApp(flow, receiver, browser).authorize(null)
}
lastCredentials = tokens
runSynchronized(this@ModelixAuthClient) {
lastCredentials[config.getCacheKey()] = tokens
}
return@withContext tokens
} catch (ex: SocketException) {
LOG.info("Port $port already in use. Trying next one.")
Expand All @@ -127,7 +135,7 @@
authConfig: IAuthConfig,
) {
when (authConfig) {
is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig.provider)
is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig)
is OAuthConfig -> installAuthWithPKCEFlow(config, authConfig)
}
}
Expand All @@ -139,11 +147,19 @@
var currentAuthConfig = initialAuthConfig

fun String.fillParameters(): String {
return if (initialAuthConfig.repositoryId == null) {
this
} else {
replace("{repositoryId}", initialAuthConfig.repositoryId.id.urlEncode())
val tokenParameters = currentAuthConfig.tokenParameters ?: return this
var result = this
if (result.contains("{repositoryId}")) {
result = result.replace("{repositoryId}", tokenParameters.getRepositoryId().orEmpty().urlEncode())
}
if (result.contains("{branchName}")) {
result = result.replace("{branchName}", tokenParameters.getBranchName().orEmpty().urlEncode())
}
return result
}

fun OAuthConfig.fillParameters(): OAuthConfig {
return copy(tokenUrl = tokenUrl?.fillParameters(), authorizationUrl = authorizationUrl?.fillParameters())
}

config.apply {
Expand All @@ -152,7 +168,7 @@
loadTokens {
// A potentially expired token is already refreshed here to avoid a 401 response.
// When a 401 response is received, we always (re-)execute the PKCE flow.
getAndMaybeRefreshTokens()?.let { BearerTokens(it.accessToken, it.refreshToken) }
getAndMaybeRefreshTokens(currentAuthConfig)?.let { BearerTokens(it.accessToken, it.refreshToken) }
}
refreshTokens {
try {
Expand Down Expand Up @@ -181,7 +197,7 @@
LOG.warn { "No client ID configured" }
return@refreshTokens null
}
val tokens = refreshTokensOrReauthorize(currentAuthConfig)
val tokens = refreshTokensOrReauthorize(currentAuthConfig.fillParameters())
checkNotNull(tokens) { "No tokens received" }

LOG.info("Access Token: " + tokens.accessToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ class ModelClientV2JvmTest {
val firstException = assertFailsWith<CancellationException> {
modelClient.init()
}
assertEquals("Parent job is Completed", firstException.message)
assertEquals("Already closed", firstException.message)
// `Closable` implies that `.close` method is idempotent.
modelClient.close()
val secondException = assertFailsWith<CancellationException> {
modelClient.init()
}
assertEquals("Parent job is Completed", secondException.message)
assertEquals("Already closed", secondException.message)
}
}
Loading
Loading