diff --git a/kotlin-utils/src/commonMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt b/kotlin-utils/src/commonMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt new file mode 100644 index 0000000000..e6a01ec70c --- /dev/null +++ b/kotlin-utils/src/commonMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt @@ -0,0 +1,11 @@ +package org.modelix.kotlin.utils + +import kotlin.annotation.AnnotationTarget.FUNCTION +import kotlin.annotation.AnnotationTarget.PROPERTY_GETTER +import kotlin.annotation.AnnotationTarget.PROPERTY_SETTER + +@OptIn(ExperimentalMultiplatform::class) +@Target(FUNCTION, PROPERTY_GETTER, PROPERTY_SETTER) +@MustBeDocumented +@OptionalExpectation +expect annotation class JvmSynchronized() diff --git a/kotlin-utils/src/jvmMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt b/kotlin-utils/src/jvmMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt new file mode 100644 index 0000000000..a7c0c80e17 --- /dev/null +++ b/kotlin-utils/src/jvmMain/kotlin/org/modelix/kotlin/utils/JvmSynchronized.kt @@ -0,0 +1,3 @@ +package org.modelix.kotlin.utils + +actual typealias JvmSynchronized = kotlin.jvm.Synchronized diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/client2/Closable.kt b/model-client/src/commonMain/kotlin/org/modelix/model/client2/Closable.kt index c95b44b36a..5ea83f656b 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/client2/Closable.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/client2/Closable.kt @@ -1,5 +1,5 @@ package org.modelix.model.client2 -internal expect interface Closable { +expect interface Closable { fun close() } diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/client2/ModelClientV2.kt b/model-client/src/commonMain/kotlin/org/modelix/model/client2/ModelClientV2.kt index 1882084248..0f68dd1c95 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/client2/ModelClientV2.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/client2/ModelClientV2.kt @@ -34,6 +34,7 @@ import io.ktor.http.contentType import io.ktor.http.takeFrom import io.ktor.serialization.kotlinx.json.json import io.ktor.utils.io.readUTF8Line +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.flow @@ -52,6 +53,7 @@ import org.modelix.datastructures.objects.IObjectGraph import org.modelix.datastructures.objects.Object import org.modelix.datastructures.objects.ObjectHash import org.modelix.kotlin.utils.DeprecationInfo +import org.modelix.kotlin.utils.JvmSynchronized import org.modelix.kotlin.utils.WeakValueMap import org.modelix.kotlin.utils.getOrPut import org.modelix.kotlin.utils.runSynchronized @@ -79,12 +81,17 @@ import org.modelix.model.mutable.IMutableModelTree import org.modelix.model.mutable.INodeIdGenerator import org.modelix.model.mutable.ModelixIdGenerator import org.modelix.model.mutable.getRootNode +import org.modelix.model.oauth.GlobalTokenParameters import org.modelix.model.oauth.IAuthConfig import org.modelix.model.oauth.IAuthRequestHandler +import org.modelix.model.oauth.ITokenParameters +import org.modelix.model.oauth.ITokenProvider import org.modelix.model.oauth.ModelixAuthClient import org.modelix.model.oauth.OAuthConfig import org.modelix.model.oauth.OAuthConfigBuilder +import org.modelix.model.oauth.TokenParameters import org.modelix.model.oauth.TokenProvider +import org.modelix.model.oauth.TokenProviderAdapter import org.modelix.model.oauth.TokenProviderAuthConfig import org.modelix.model.operations.OTBranch import org.modelix.model.persistent.CPVersion @@ -108,7 +115,7 @@ import kotlin.time.Duration.Companion.seconds class VersionNotFoundException(val versionHash: String) : Exception("Version $versionHash not found") class ModelClientV2( - val httpClient: HttpClient, + val httpClientProvider: IHttpClientProvider, val baseUrl: String, private var clientProvidedUserId: String?, var defaultGraphConfig: ModelClientGraphConfig, @@ -152,7 +159,7 @@ class ModelClientV2( } override suspend fun getServerId(): String { - return httpClient.get { + return httpClientProvider.getHttpClient().get { url { takeFrom(baseUrl) appendPathSegments("server-id") @@ -161,7 +168,7 @@ class ModelClientV2( } private suspend fun updateClientId() { - this.clientId = httpClient.post { + this.clientId = httpClientProvider.getHttpClient().post { url { takeFrom(baseUrl) appendPathSegments("generate-client-id") @@ -171,7 +178,7 @@ class ModelClientV2( } suspend fun updateUserId() { - serverProvidedUserId = httpClient.get { + serverProvidedUserId = httpClientProvider.getHttpClient().get { url { takeFrom(baseUrl) appendPathSegments("user-id") @@ -229,7 +236,7 @@ class ModelClientV2( override suspend fun initRepository(config: RepositoryConfig): IVersion { val repositoryId = RepositoryId(config.repositoryId) - return httpClient.preparePost { + return httpClientProvider.getHttpClient(RepositoryId(config.repositoryId)).preparePost { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repositoryId.id, "init") @@ -246,7 +253,7 @@ class ModelClientV2( source: RepositoryId, target: RepositoryId?, ): RepositoryId { - return httpClient.preparePost { + return httpClientProvider.getHttpClient(source).preparePost { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", source.id, "fork") @@ -261,7 +268,7 @@ class ModelClientV2( override suspend fun changeRepositoryConfig(config: RepositoryConfig): RepositoryConfig { val repositoryId = RepositoryId(config.repositoryId) - return httpClient.preparePost { + return httpClientProvider.getHttpClient(RepositoryId(config.repositoryId)).preparePost { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repositoryId.id, "config") @@ -274,7 +281,7 @@ class ModelClientV2( } override suspend fun getRepositoryConfig(repository: RepositoryId): RepositoryConfig { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(repository).prepareGet { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repository.id, "config") @@ -283,7 +290,7 @@ class ModelClientV2( } override suspend fun listRepositories(): List { - return httpClient.get { + return httpClientProvider.getHttpClient().get { url { takeFrom(baseUrl) appendPathSegments("repositories") @@ -293,7 +300,7 @@ class ModelClientV2( override suspend fun deleteRepository(repository: RepositoryId): Boolean { try { - return httpClient.post { + return httpClientProvider.getHttpClient(repository).post { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repository.id, "delete") @@ -306,7 +313,7 @@ class ModelClientV2( } override suspend fun listBranches(repository: RepositoryId): List { - return httpClient.get { + return httpClientProvider.getHttpClient(repository).get { // only accept text/plain, not application/json accept(ContentType.Text.Plain) exclude(ContentType.Application.Json) @@ -318,7 +325,7 @@ class ModelClientV2( } override suspend fun listBranchesWithHashes(repository: RepositoryId): List { - return httpClient.get { + return httpClientProvider.getHttpClient(repository).get { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repository.id, "branches") @@ -328,7 +335,7 @@ class ModelClientV2( override suspend fun deleteBranch(branch: BranchReference): Boolean { try { - return httpClient.delete { + return httpClientProvider.getHttpClient(branch).delete { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash( @@ -380,7 +387,7 @@ class ModelClientV2( delay: Duration, pagination: PaginationParameters, ): List { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(repositoryId).prepareGet { url { takeFrom(baseUrl) appendPathSegments("repositories", repositoryId.id, "versions", headVersion.toString(), "history", "sessions") @@ -402,7 +409,7 @@ class ModelClientV2( interval: Duration, pagination: PaginationParameters, ): List { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(repositoryId).prepareGet { url { takeFrom(baseUrl) appendPathSegments("repositories", repositoryId.id, "versions", headVersion.toString(), "history", "intervals") @@ -423,7 +430,7 @@ class ModelClientV2( timeRange: ClosedRange?, pagination: PaginationParameters, ): List { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(repositoryId).prepareGet { url { takeFrom(baseUrl) appendPathSegments("repositories", repositoryId.id, "versions", headVersion.toString(), "history", "entries") @@ -440,7 +447,7 @@ class ModelClientV2( } override suspend fun splitAt(splitPoints: List): List { - return httpClient.preparePost { + return httpClientProvider.getHttpClient(repositoryId).preparePost { url { takeFrom(baseUrl) appendPathSegments("repositories", repositoryId.id, "versions", headVersion.toString(), "history", "intervals") @@ -457,6 +464,11 @@ class ModelClientV2( repositoryId: RepositoryId?, versionHash: ObjectHash, ): Object { + val httpClient = if (repositoryId == null) { + httpClientProvider.getHttpClient() + } else { + httpClientProvider.getHttpClient(repositoryId) + } return httpClient.prepareGet { url { takeFrom(baseUrl) @@ -482,7 +494,7 @@ class ModelClientV2( branch: BranchReference, versionHash: ObjectHash, ): ObjectHash { - return httpClient.preparePost { + return httpClientProvider.getHttpClient(branch).preparePost { url { takeFrom(baseUrl) appendPathSegments("repositories", branch.repositoryId.id, "branches", branch.branchName, "revert") @@ -521,6 +533,11 @@ class ModelClientV2( baseVersion: IVersion?, ): IVersion { checkCreatedByThisClient(baseVersion, repositoryId) + val httpClient = if (repositoryId == null) { + httpClientProvider.getHttpClient() + } else { + httpClientProvider.getHttpClient(repositoryId) + } return httpClient.prepareGet { url { takeFrom(baseUrl) @@ -541,7 +558,7 @@ class ModelClientV2( override suspend fun getObjects(repository: RepositoryId, keys: Sequence): Map { LOG.debug { "${clientId.toString(16)}.getObjects($repository)" } - return httpClient.preparePost { + return httpClientProvider.getHttpClient(repository).preparePost { url { takeFrom(baseUrl) appendPathSegments("repositories", repository.id, "objects", "getAll") @@ -584,7 +601,7 @@ class ModelClientV2( VersionDelta(version.getContentHash(), null, objectsMap = lastChunk.toMap()) } - return httpClient.preparePost { + return httpClientProvider.getHttpClient(branch).preparePost { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName) @@ -630,7 +647,7 @@ class ModelClientV2( val chunkEntries = ArrayList() suspend fun sendChunk() { - httpClient.put { + httpClientProvider.getHttpClient(repository).put { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repository.id, "objects") @@ -660,7 +677,7 @@ class ModelClientV2( override suspend fun pull(branch: BranchReference, lastKnownVersion: IVersion?, filter: ObjectDeltaFilter): IVersion { require(lastKnownVersion is CLVersion?) checkCreatedByThisClient(lastKnownVersion, branch.repositoryId) - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(branch).prepareGet { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName) @@ -680,7 +697,7 @@ class ModelClientV2( } override suspend fun pullIfExists(branch: BranchReference): IVersion? { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(branch).prepareGet { expectSuccess = false url { takeFrom(baseUrl) @@ -699,7 +716,7 @@ class ModelClientV2( } override suspend fun pullHash(branch: BranchReference): String { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(branch).prepareGet { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName, "hash") @@ -711,7 +728,7 @@ class ModelClientV2( } override suspend fun pullHashIfExists(branch: BranchReference): String? { - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(branch).prepareGet { expectSuccess = false url { takeFrom(baseUrl) @@ -732,7 +749,7 @@ class ModelClientV2( } override suspend fun pollHash(branch: BranchReference, lastKnownHash: String?): String { - val response = httpClient.get { + val response = httpClientProvider.getHttpClient(branch).get { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName, "pollHash") @@ -749,7 +766,7 @@ class ModelClientV2( require(lastKnownVersion is CLVersion?) checkCreatedByThisClient(lastKnownVersion, branch.repositoryId) LOG.debug { "${clientId.toString(16)}.poll($branch, $lastKnownVersion)" } - return httpClient.prepareGet { + return httpClientProvider.getHttpClient(branch).prepareGet { url { takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName, "poll") @@ -774,7 +791,7 @@ class ModelClientV2( takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", branch.repositoryId.id, "branches", branch.branchName, "query") } - return ModelQLClient.builder().httpClient(httpClient).url(url.buildString()).build().query(body) + return ModelQLClient.builder().httpClient(httpClientProvider.getHttpClient(branch)).url(url.buildString()).build().query(body) } override suspend fun query(repositoryId: RepositoryId, versionHash: String, body: (IMonoStep) -> IMonoStep): R { @@ -782,7 +799,7 @@ class ModelClientV2( takeFrom(baseUrl) appendPathSegmentsEncodingSlash("repositories", repositoryId.id, "versions", versionHash, "query") } - return ModelQLClient.builder().httpClient(httpClient).url(url.buildString()).build().query(body) + return ModelQLClient.builder().httpClient(httpClientProvider.getHttpClient(repositoryId)).url(url.buildString()).build().query(body) } override fun getFrontendUrl(branch: BranchReference): Url { @@ -793,7 +810,7 @@ class ModelClientV2( } override fun close() { - httpClient.close() + httpClientProvider.close() } private suspend fun createVersion(repository: RepositoryId, baseVersion: CLVersion?, delta: VersionDeltaStream): CLVersion { @@ -846,13 +863,19 @@ abstract class ModelClientV2Builder { protected var connectTimeout: Duration = 1.seconds protected var requestTimeout: Duration = 300.seconds protected var defaultGraphConfig = ModelClientGraphConfig() + protected val authClient by lazy { ModelixAuthClient() } // 0 and 1 mean "disable retries" protected var retries: UInt = 3U fun build(): ModelClientV2 { return ModelClientV2( - httpClient = httpClient?.config { configureHttpClient(this) } ?: createHttpClient(), + httpClientProvider = object : CachingHttpClientProvider() { + override fun createInstance(tokenParameters: ITokenParameters): HttpClient { + return httpClient?.config { configureHttpClient(this, tokenParameters) } + ?: createHttpClient(tokenParameters) + } + }, baseUrl = baseUrl, clientProvidedUserId = userId, defaultGraphConfig = defaultGraphConfig, @@ -879,7 +902,10 @@ abstract class ModelClientV2Builder { return this } - fun authToken(provider: TokenProvider) = also { + @Deprecated("Provide an ITokenProvider") + fun authToken(provider: TokenProvider) = authToken(TokenProviderAdapter(provider)) + + fun authToken(provider: ITokenProvider) = also { authConfig = TokenProviderAuthConfig(provider) } @@ -970,7 +996,7 @@ abstract class ModelClientV2Builder { return this } - protected open fun configureHttpClient(config: HttpClientConfig<*>) { + protected open fun configureHttpClient(config: HttpClientConfig<*>, tokenParameters: ITokenParameters) { config.apply { expectSuccess = true followRedirects = false @@ -1001,11 +1027,11 @@ abstract class ModelClientV2Builder { } } } - authConfig?.let { ModelixAuthClient().installAuth(this, it) } + authConfig?.withTokenParameters(tokenParameters)?.let { authClient.installAuth(this, it) } } } - protected abstract fun createHttpClient(): HttpClient + protected abstract fun createHttpClient(tokenParameters: ITokenParameters): HttpClient companion object { private val LOG = KotlinLogging.logger {} @@ -1020,8 +1046,56 @@ abstract class ModelClientV2Builder { } } +interface IHttpClientProvider : Closable { + fun getHttpClient(): HttpClient + fun getHttpClient(repository: RepositoryId): HttpClient + fun getHttpClient(branch: BranchReference): HttpClient +} + +abstract class CachingHttpClientProvider : IHttpClientProvider { + private var serverInstance: HttpClient? = null + private val repositoryInstances = HashMap() + private val branchInstances = HashMap() + private var closed: Boolean = false + + abstract fun createInstance(tokenParameters: ITokenParameters): HttpClient + + @JvmSynchronized + override fun getHttpClient(): HttpClient { + checkClosed() + return serverInstance ?: createInstance(GlobalTokenParameters()).also { serverInstance = it } + } + + @JvmSynchronized + override fun getHttpClient(repository: RepositoryId): HttpClient { + checkClosed() + return repositoryInstances.getOrPut(repository) { createInstance(TokenParameters(repository)) } + } + + @JvmSynchronized + override fun getHttpClient(branch: BranchReference): HttpClient { + checkClosed() + return branchInstances.getOrPut(branch) { createInstance(TokenParameters(branch)) } + } + + private fun checkClosed() { + if (closed) throw CancellationException("Already closed") + } + + @JvmSynchronized + override fun close() { + closed = true + serverInstance?.close() + serverInstance = null + repositoryInstances.values.forEach { it.close() } + repositoryInstances.clear() + branchInstances.values.forEach { it.close() } + branchInstances.clear() + } +} + expect class ModelClientV2PlatformSpecificBuilder() : ModelClientV2Builder { - override fun createHttpClient(): HttpClient + override fun createHttpClient(tokenParameters: ITokenParameters): HttpClient } fun VersionDelta.checkObjectHashes() { diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt index 36e2571705..6566c7901c 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt @@ -24,15 +24,16 @@ expect class ModelixAuthClient() { ) } -internal fun installAuthWithAuthTokenProvider(config: HttpClientConfig<*>, authTokenProvider: suspend () -> String?) { +internal fun installAuthWithAuthTokenProvider(config: HttpClientConfig<*>, authConfig: TokenProviderAuthConfig) { config.apply { install(Auth) { bearer { loadTokens { - authTokenProvider()?.let { authToken -> BearerTokens(authToken, "") } + authConfig.provider.getToken(authConfig.tokenParameters) + ?.let { authToken -> BearerTokens(authToken, "") } } refreshTokens { - val providedToken = authTokenProvider() + val providedToken = authConfig.provider.getToken(authConfig.tokenParameters) if (providedToken != null && providedToken != this.oldTokens?.accessToken) { BearerTokens(providedToken, "") } else { diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt index f98c768fec..ac941aabde 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt @@ -1,10 +1,16 @@ package org.modelix.model.oauth +import org.modelix.model.lazy.BranchReference import org.modelix.model.lazy.RepositoryId sealed interface IAuthConfig { companion object { + @Deprecated("Provide an ITokenProvider") fun fromTokenProvider(provider: TokenProvider): IAuthConfig { + return fromTokenProvider(TokenProviderAdapter(provider)) + } + + fun fromTokenProvider(provider: ITokenProvider): IAuthConfig { return TokenProviderAuthConfig(provider) } @@ -12,22 +18,96 @@ sealed interface IAuthConfig { return OAuthConfigBuilder(null).apply(body).build() } } + + fun withTokenParameters(parameters: ITokenParameters): IAuthConfig } -class TokenProviderAuthConfig(val provider: TokenProvider) : IAuthConfig +class TokenProviderAuthConfig( + val provider: ITokenProvider, + val tokenParameters: ITokenParameters = GlobalTokenParameters(), +) : IAuthConfig { + override fun withTokenParameters(parameters: ITokenParameters): IAuthConfig { + return TokenProviderAuthConfig(provider, parameters) + } +} data class OAuthConfig( val clientId: String? = "external-mps", val clientSecret: String? = null, val authorizationUrl: String? = null, val tokenUrl: String? = null, - val repositoryId: RepositoryId? = null, + val tokenParameters: ITokenParameters? = null, val scopes: Set = emptySet(), val authRequestHandler: IAuthRequestHandler? = null, -) : IAuthConfig +) : IAuthConfig { + fun getCacheKey() = TokenCacheKey( + clientId = clientId, + clientSecret = clientSecret, + authorizationUrl = authorizationUrl, + tokenUrl = tokenUrl, + scopes = scopes, + ) + override fun withTokenParameters(parameters: ITokenParameters): IAuthConfig { + return copy(tokenParameters = parameters) + } +} + +data class TokenCacheKey( + val clientId: String?, + val clientSecret: String?, + val authorizationUrl: String?, + val tokenUrl: String?, + val scopes: Set, +) + +@Deprecated("use ITokenProvider") typealias TokenProvider = suspend () -> String? +interface ITokenProvider { + suspend fun getToken(): String? = null + suspend fun getToken(parameters: ITokenParameters): String? = getToken() +} + +class TokenProviderAdapter(val provider: suspend () -> String?) : ITokenProvider { + override suspend fun getToken(): String? = provider() +} + +interface ITokenParameters { + fun getRepositoryId(): String? + fun getBranchName(): String? +} + +class TokenParameters(private val repositoryId: RepositoryId?, private val branchName: String?) : ITokenParameters { + constructor(repositoryId: RepositoryId) : this(repositoryId, null) + constructor(branchReference: BranchReference) : this(branchReference.repositoryId, branchReference.branchName) + + private var dependsOnRepositoryId: Boolean = false + private var dependsOnBranchName: Boolean = false + + override fun getRepositoryId(): String? { + dependsOnRepositoryId = true + return repositoryId?.id + } + + override fun getBranchName(): String? { + dependsOnBranchName = true + return branchName + } + + fun createCacheKey(): Any { + return listOf( + repositoryId?.id?.takeIf { dependsOnRepositoryId }, + branchName?.takeIf { dependsOnBranchName }, + ) + } +} + +class GlobalTokenParameters : ITokenParameters { + override fun getRepositoryId(): String? = null + override fun getBranchName(): String? = null +} + class OAuthConfigBuilder(initial: OAuthConfig?) { private var config = initial ?: OAuthConfig() @@ -39,7 +119,13 @@ class OAuthConfigBuilder(initial: OAuthConfig?) { fun additionalScope(scope: String) = additionalScopes(setOf(scope)) fun authorizationUrl(url: String) = also { config = config.copy(authorizationUrl = url) } fun tokenUrl(url: String) = also { config = config.copy(tokenUrl = url) } - fun repositoryId(repositoryId: RepositoryId) = also { config = config.copy(repositoryId = repositoryId) } + fun repositoryId(repositoryId: RepositoryId) = tokenParameters( + TokenParameters( + repositoryId = repositoryId, + branchName = null, + ), + ) + fun tokenParameters(parameters: ITokenParameters) = also { config = config.copy(tokenParameters = parameters) } fun oidcUrl(url: String) = authorizationUrl(url.trimEnd('/') + "/auth").tokenUrl(url.trimEnd('/') + "/token") fun authRequestHandler(handler: IAuthRequestHandler?) = also { config = config.copy(authRequestHandler = handler) } diff --git a/model-client/src/commonTest/kotlin/org/modelix/model/client2/ModelClientV2Test.kt b/model-client/src/commonTest/kotlin/org/modelix/model/client2/ModelClientV2Test.kt index 54de639643..2ddcee6a7f 100644 --- a/model-client/src/commonTest/kotlin/org/modelix/model/client2/ModelClientV2Test.kt +++ b/model-client/src/commonTest/kotlin/org/modelix/model/client2/ModelClientV2Test.kt @@ -33,7 +33,7 @@ class ModelClientV2Test { val exception = assertFailsWith { modelClient.init() } - assertEquals("Parent job is Completed", exception.message) + assertEquals("Already closed", exception.message) } @Test diff --git a/model-client/src/jsMain/kotlin/org/modelix/model/client2/Closable.js.kt b/model-client/src/jsMain/kotlin/org/modelix/model/client2/Closable.js.kt index f7d062cf3d..7b7c1e1957 100644 --- a/model-client/src/jsMain/kotlin/org/modelix/model/client2/Closable.js.kt +++ b/model-client/src/jsMain/kotlin/org/modelix/model/client2/Closable.js.kt @@ -1,5 +1,5 @@ package org.modelix.model.client2 -internal actual interface Closable { +actual interface Closable { actual fun close() } diff --git a/model-client/src/jsMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt b/model-client/src/jsMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt index debd8aede4..5930136f64 100644 --- a/model-client/src/jsMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt +++ b/model-client/src/jsMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt @@ -2,11 +2,12 @@ package org.modelix.model.client2 import io.ktor.client.HttpClient import io.ktor.client.engine.js.Js +import org.modelix.model.oauth.ITokenParameters actual class ModelClientV2PlatformSpecificBuilder : ModelClientV2Builder() { - actual override fun createHttpClient(): HttpClient { + actual override fun createHttpClient(tokenParameters: ITokenParameters): HttpClient { return HttpClient(Js) { - configureHttpClient(this) + configureHttpClient(this, tokenParameters) } } } diff --git a/model-client/src/jsMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt b/model-client/src/jsMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt index f877a2f486..a474c6f17a 100644 --- a/model-client/src/jsMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt +++ b/model-client/src/jsMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt @@ -10,8 +10,8 @@ actual class ModelixAuthClient { authConfig: IAuthConfig, ) { when (authConfig) { - is OAuthConfig -> UnsupportedOperationException("JS client doesn't support OAuth2") - is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig.provider) + is OAuthConfig -> throw UnsupportedOperationException("JS client doesn't support OAuth2") + is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig) } } } diff --git a/model-client/src/jvmMain/kotlin/org/modelix/model/client/RestWebModelClient.kt b/model-client/src/jvmMain/kotlin/org/modelix/model/client/RestWebModelClient.kt index 794647b2ef..db12df5211 100644 --- a/model-client/src/jvmMain/kotlin/org/modelix/model/client/RestWebModelClient.kt +++ b/model-client/src/jvmMain/kotlin/org/modelix/model/client/RestWebModelClient.kt @@ -46,7 +46,6 @@ import org.modelix.model.KeyValueStoreCache import org.modelix.model.api.IIdGenerator import org.modelix.model.lazy.IDeserializingKeyValueStore import org.modelix.model.lazy.createObjectStoreCache -import org.modelix.model.oauth.ModelixAuthClient import org.modelix.model.persistent.HashUtil import org.modelix.model.sleep import org.modelix.model.util.StreamUtils.toStream @@ -149,13 +148,12 @@ class RestWebModelClient @JvmOverloads constructor( }, ) } - val modelixAuthClient = ModelixAuthClient() install(Auth) { bearer { loadTokens { val tp = authTokenProvider if (tp == null) { - modelixAuthClient.getTokens()?.let { BearerTokens(it.accessToken, it.refreshToken) } + null } else { val token = tp() if (token == null) { diff --git a/model-client/src/jvmMain/kotlin/org/modelix/model/client2/Closable.jvm.kt b/model-client/src/jvmMain/kotlin/org/modelix/model/client2/Closable.jvm.kt index 55f26eaa0a..ccde390c41 100644 --- a/model-client/src/jvmMain/kotlin/org/modelix/model/client2/Closable.jvm.kt +++ b/model-client/src/jvmMain/kotlin/org/modelix/model/client2/Closable.jvm.kt @@ -1,5 +1,5 @@ package org.modelix.model.client2 -internal actual interface Closable : java.io.Closeable { +actual interface Closable : java.io.Closeable { actual override fun close() } diff --git a/model-client/src/jvmMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt b/model-client/src/jvmMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt index 03870b6662..ef3ada7095 100644 --- a/model-client/src/jvmMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt +++ b/model-client/src/jvmMain/kotlin/org/modelix/model/client2/ModelClientV2PlatformSpecificBuilder.kt @@ -2,11 +2,12 @@ package org.modelix.model.client2 import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO +import org.modelix.model.oauth.ITokenParameters actual class ModelClientV2PlatformSpecificBuilder : ModelClientV2Builder() { - actual override fun createHttpClient(): HttpClient { + actual override fun createHttpClient(tokenParameters: ITokenParameters): HttpClient { return HttpClient(CIO) { - configureHttpClient(this) + configureHttpClient(this, tokenParameters) } } } diff --git a/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt b/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt index 69692165dd..117ad14587 100644 --- a/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt +++ b/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt @@ -33,6 +33,7 @@ import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import org.modelix.kotlin.utils.runSynchronized import org.modelix.kotlin.utils.urlEncode import java.net.SocketException import java.net.SocketTimeoutException @@ -44,16 +45,19 @@ actual class ModelixAuthClient { } private val httpTransport: HttpTransport = NetHttpTransport() private val jsonFactory: JsonFactory = GsonFactory() - private var lastCredentials: Credential? = null + private var lastCredentials = HashMap() - fun getTokens(): Credential? { - return lastCredentials?.takeIf { !it.isExpired() } + @Synchronized + fun getTokens(config: OAuthConfig): Credential? { + return lastCredentials[config.getCacheKey()]?.takeIf { !it.isExpired() } } + @Synchronized private fun Credential.isExpired(): Boolean { return (expiresInSeconds ?: return false) < 60 } + @Synchronized private fun Credential.refreshIfExpired(): Credential? { return if (isExpired()) { alwaysRefresh() @@ -62,6 +66,7 @@ actual class ModelixAuthClient { } } + @Synchronized fun Credential.alwaysRefresh(): Credential? { for (attempt in 1..3) { try { @@ -77,12 +82,13 @@ actual class ModelixAuthClient { return null } - suspend fun getAndMaybeRefreshTokens(): Credential? { - return lastCredentials?.refreshIfExpired() + @Synchronized + fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? { + return lastCredentials[config.getCacheKey()]?.refreshIfExpired() } suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? { - return lastCredentials?.alwaysRefresh() ?: authorize(config) + return runSynchronized(this) { lastCredentials[config.getCacheKey()]?.alwaysRefresh() } ?: authorize(config) } suspend fun authorize(config: OAuthConfig): Credential? { @@ -110,7 +116,9 @@ actual class ModelixAuthClient { val tokens = cancelable({ receiver.stop() }) { AuthorizationCodeInstalledApp(flow, receiver, browser).authorize(null) } - lastCredentials = tokens + runSynchronized(this@ModelixAuthClient) { + lastCredentials[config.getCacheKey()] = tokens + } return@withContext tokens } catch (ex: SocketException) { LOG.info("Port $port already in use. Trying next one.") @@ -127,7 +135,7 @@ actual class ModelixAuthClient { authConfig: IAuthConfig, ) { when (authConfig) { - is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig.provider) + is TokenProviderAuthConfig -> installAuthWithAuthTokenProvider(config, authConfig) is OAuthConfig -> installAuthWithPKCEFlow(config, authConfig) } } @@ -139,11 +147,19 @@ actual class ModelixAuthClient { var currentAuthConfig = initialAuthConfig fun String.fillParameters(): String { - return if (initialAuthConfig.repositoryId == null) { - this - } else { - replace("{repositoryId}", initialAuthConfig.repositoryId.id.urlEncode()) + val tokenParameters = currentAuthConfig.tokenParameters ?: return this + var result = this + if (result.contains("{repositoryId}")) { + result = result.replace("{repositoryId}", tokenParameters.getRepositoryId().orEmpty().urlEncode()) + } + if (result.contains("{branchName}")) { + result = result.replace("{branchName}", tokenParameters.getBranchName().orEmpty().urlEncode()) } + return result + } + + fun OAuthConfig.fillParameters(): OAuthConfig { + return copy(tokenUrl = tokenUrl?.fillParameters(), authorizationUrl = authorizationUrl?.fillParameters()) } config.apply { @@ -152,7 +168,7 @@ actual class ModelixAuthClient { loadTokens { // A potentially expired token is already refreshed here to avoid a 401 response. // When a 401 response is received, we always (re-)execute the PKCE flow. - getAndMaybeRefreshTokens()?.let { BearerTokens(it.accessToken, it.refreshToken) } + getAndMaybeRefreshTokens(currentAuthConfig)?.let { BearerTokens(it.accessToken, it.refreshToken) } } refreshTokens { try { @@ -181,7 +197,7 @@ actual class ModelixAuthClient { LOG.warn { "No client ID configured" } return@refreshTokens null } - val tokens = refreshTokensOrReauthorize(currentAuthConfig) + val tokens = refreshTokensOrReauthorize(currentAuthConfig.fillParameters()) checkNotNull(tokens) { "No tokens received" } LOG.info("Access Token: " + tokens.accessToken) diff --git a/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelClientV2JvmTest.kt b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelClientV2JvmTest.kt index 4ee47a3ec3..8a2f002d70 100644 --- a/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelClientV2JvmTest.kt +++ b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/ModelClientV2JvmTest.kt @@ -29,12 +29,12 @@ class ModelClientV2JvmTest { val firstException = assertFailsWith { modelClient.init() } - assertEquals("Parent job is Completed", firstException.message) + assertEquals("Already closed", firstException.message) // `Closable` implies that `.close` method is idempotent. modelClient.close() val secondException = assertFailsWith { modelClient.init() } - assertEquals("Parent job is Completed", secondException.message) + assertEquals("Already closed", secondException.message) } } diff --git a/model-server/src/test/kotlin/org/modelix/model/server/ModelClientV2Test.kt b/model-server/src/test/kotlin/org/modelix/model/server/ModelClientV2Test.kt index b83587fe4d..72f5c2742e 100644 --- a/model-server/src/test/kotlin/org/modelix/model/server/ModelClientV2Test.kt +++ b/model-server/src/test/kotlin/org/modelix/model/server/ModelClientV2Test.kt @@ -464,7 +464,7 @@ class ModelClientV2Test { } suspend fun sendRequest(mapper: suspend (HttpResponse) -> R): R { - return modelClient.httpClient.prepareGet { + return modelClient.httpClientProvider.getHttpClient(branch).prepareGet { url { takeFrom(modelClient.baseUrl) appendPathSegments("repositories", branch.repositoryId.id, "branches", branch.branchName) diff --git a/model-server/src/test/kotlin/org/modelix/model/server/TokenManagementTest.kt b/model-server/src/test/kotlin/org/modelix/model/server/TokenManagementTest.kt new file mode 100644 index 0000000000..e688f84d2b --- /dev/null +++ b/model-server/src/test/kotlin/org/modelix/model/server/TokenManagementTest.kt @@ -0,0 +1,202 @@ +package org.modelix.model.server + +import com.auth0.jwt.JWT +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.request.get +import io.ktor.http.buildUrl +import io.ktor.http.parameters +import io.ktor.http.takeFrom +import io.ktor.server.application.install +import io.ktor.server.engine.embeddedServer +import io.ktor.server.netty.Netty +import io.ktor.server.response.respond +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import mu.KotlinLogging +import org.modelix.authorization.IModelixAuthorizationConfig +import org.modelix.authorization.ModelixAuthorization +import org.modelix.authorization.ModelixJWTUtil +import org.modelix.authorization.createModelixAccessToken +import org.modelix.model.client2.ModelClientV2 +import org.modelix.model.lazy.RepositoryId +import org.modelix.model.oauth.IAuthConfig +import org.modelix.model.oauth.IAuthRequestHandler +import org.modelix.model.oauth.ITokenParameters +import org.modelix.model.oauth.ITokenProvider +import org.modelix.model.server.handlers.IdsApiImpl +import org.modelix.model.server.handlers.ModelReplicationServer +import org.modelix.model.server.handlers.RepositoriesManager +import org.modelix.model.server.store.InMemoryStoreClient +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals + +private val LOG = KotlinLogging.logger { } + +@Serializable +private data class TokenResponse( + @SerialName("access_token") + val accessToken: String, +) + +class TokenManagementTest { + + private fun runWithInProcessServer(block: suspend ApplicationTestBuilder.() -> Unit) = testApplication { + application { + try { + val modelixAuthorizationConfig: IModelixAuthorizationConfig.() -> Unit = { + permissionSchema = ModelServerPermissionSchema.SCHEMA + hmac512Key = "my-hmac-key" + } + install(ModelixAuthorization, modelixAuthorizationConfig) + installDefaultServerPlugins() + val repoManager = RepositoriesManager(InMemoryStoreClient()) + ModelReplicationServer(repoManager).init(this) + IdsApiImpl(repoManager).init(this) + } catch (ex: Throwable) { + LOG.error("", ex) + } + } + block() + } + + @Test + fun `separate token is used for each repository when using token provider`() = runWithInProcessServer { + val createdTokens = ArrayList() + + fun getPermissions(index: Int) = ModelixJWTUtil().extractPermissions(JWT.decode(createdTokens[index])) + + val client = ModelClientV2.builder() + .url("http://localhost/v2") + .client(client) + .authToken(object : ITokenProvider { + override suspend fun getToken(parameters: ITokenParameters): String { + return createModelixAccessToken( + hmac512key = "my-hmac-key", + user = "token-test@modelix.org", + grantedPermissions = listOfNotNull( + parameters.getRepositoryId()?.let { repository -> + ModelServerPermissionSchema.repository(repository).write.fullId + }, + parameters.getBranchName()?.let { branchName -> + ModelServerPermissionSchema.repository(parameters.getRepositoryId()!!).branch(branchName).write.fullId + }, + ), + ).also { createdTokens += it } + } + }) + .build() + assertEquals(0, createdTokens.size) + client.init() + assertEquals(1, createdTokens.size) + assertEquals(listOf(), getPermissions(0)) + + val repoId1 = RepositoryId("repo1") + client.initRepository(repoId1) + assertEquals(2, createdTokens.size) + assertEquals(listOf(ModelServerPermissionSchema.repository("repo1").write.fullId), getPermissions(1)) + + val repoId2 = RepositoryId("repo2") + client.initRepository(repoId2) + assertEquals(3, createdTokens.size) + assertEquals(listOf(ModelServerPermissionSchema.repository("repo2").write.fullId), getPermissions(2)) + } + + @Test + fun `separate token is used for each repository when using token endpoint`() = runTest { + val createdTokens = ArrayList() + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + // real server need instead of ktor.test because the PKCE flow is implemented by a non-ktor client + val server = embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + try { + val modelixAuthorizationConfig: IModelixAuthorizationConfig.() -> Unit = { + permissionSchema = ModelServerPermissionSchema.SCHEMA + hmac512Key = "my-hmac-key" + } + install(ModelixAuthorization, modelixAuthorizationConfig) + installDefaultServerPlugins() + val repoManager = RepositoriesManager(InMemoryStoreClient()) + ModelReplicationServer(repoManager).init(this) + IdsApiImpl(repoManager).init(this) + } catch (ex: Throwable) { + LOG.error("", ex) + } + routing { + post("/token") { + val repositoryId = call.queryParameters["repository-id"]?.takeIf { it.isNotEmpty() } + val branchName = call.queryParameters["branch-name"]?.takeIf { it.isNotEmpty() } + val token = createModelixAccessToken( + hmac512key = "my-hmac-key", + user = "token-test@modelix.org", + grantedPermissions = listOfNotNull( + repositoryId?.let { repository -> + ModelServerPermissionSchema.repository(repository).write.fullId + }, + branchName?.let { branchName -> + ModelServerPermissionSchema.repository(repositoryId!!).branch(branchName).write.fullId + }, + ), + ).also { createdTokens += it } + call.respond(TokenResponse(accessToken = token)) + } + } + }.startSuspend() + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + fun getPermissions(index: Int) = ModelixJWTUtil().extractPermissions(JWT.decode(createdTokens[index])) + + runWithServer { port -> + val client = ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + authRequestHandler(object : IAuthRequestHandler { + override fun browse(url: String) { + // https://localhost/realms/modelix/protocol/openid-connect/auth?client_id=my-client-id&code_challenge=YzBhqU2-lRzCkoSLVc0BGN3_AlwU5YUpYS1_m_6FMbI&code_challenge_method=S256&redirect_uri=http://127.0.0.1:64186/Callback&response_type=code&scope=email + val redirectUri = io.ktor.http.Url(url).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { + HttpClient(CIO).get(callbackWithCode) + } + } + }) + clientId("my-client-id") + tokenUrl("http://localhost:$port/token?repository-id={repositoryId}") + authorizationUrl("http://localhost:$port/auth") + }, + ) + .build() + assertEquals(0, createdTokens.size) + client.init() + assertEquals(1, createdTokens.size) + assertEquals(listOf(), getPermissions(0)) + + val repoId1 = RepositoryId("repo1") + client.initRepository(repoId1) + assertEquals(2, createdTokens.size) + assertEquals(listOf(ModelServerPermissionSchema.repository("repo1").write.fullId), getPermissions(1)) + + val repoId2 = RepositoryId("repo2") + client.initRepository(repoId2) + assertEquals(3, createdTokens.size) + assertEquals(listOf(ModelServerPermissionSchema.repository("repo2").write.fullId), getPermissions(2)) + } + } +} diff --git a/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/IModelSyncService.kt b/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/IModelSyncService.kt index 0ba3dbdb3c..8aa5c412d2 100644 --- a/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/IModelSyncService.kt +++ b/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/IModelSyncService.kt @@ -11,6 +11,7 @@ import org.modelix.model.client2.IModelClientV2 import org.modelix.model.lazy.BranchReference import org.modelix.model.lazy.RepositoryId import org.modelix.model.mpsadapters.toModelix +import org.modelix.model.oauth.ITokenProvider import org.modelix.model.oauth.OAuthConfigBuilder import org.modelix.mps.multiplatform.model.MPSModuleReference import java.io.Closeable @@ -59,7 +60,9 @@ data class ModelServerConnectionProperties( ) interface IServerConnection : Closeable { + @Deprecated("Provide an ITokenProvider") fun setTokenProvider(tokenProvider: (suspend () -> String?)) + fun setTokenProvider(tokenProvider: ITokenProvider) fun configureOAuth(body: OAuthConfigBuilder.() -> Unit) fun activate() diff --git a/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/ModelSyncService.kt b/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/ModelSyncService.kt index 6b3fe79962..e6c1838698 100644 --- a/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/ModelSyncService.kt +++ b/mps-sync-plugin3/src/main/kotlin/org/modelix/mps/sync3/ModelSyncService.kt @@ -22,6 +22,7 @@ import org.modelix.model.client2.IModelClientV2 import org.modelix.model.lazy.BranchReference import org.modelix.model.mpsadapters.MPSProjectAsNode import org.modelix.model.oauth.IAuthConfig +import org.modelix.model.oauth.ITokenProvider import org.modelix.model.oauth.OAuthConfigBuilder import org.modelix.model.oauth.TokenProvider import org.modelix.mps.multiplatform.model.MPSModuleReference @@ -230,10 +231,15 @@ class ModelSyncService(val project: Project) : } inner class Connection(val connection: AppLevelModelSyncService.ServerConnection) : IServerConnection { + @Deprecated("Provide an ITokenProvider") override fun setTokenProvider(tokenProvider: TokenProvider) { connection.setAuthorizationConfig(IAuthConfig.fromTokenProvider(tokenProvider)) } + override fun setTokenProvider(tokenProvider: ITokenProvider) { + connection.setAuthorizationConfig(IAuthConfig.fromTokenProvider(tokenProvider)) + } + override fun configureOAuth(body: OAuthConfigBuilder.() -> Unit) { connection.configureOAuth(body) }