Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ repositories {

dependencies {
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
implementation "org.mongodb:mongodb-driver-sync:4.11.1"
implementation "org.mongodb:mongodb-driver-sync:5.6.1"
implementation group: 'org.jetbrains', name: 'annotations', version: '15.0'
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation group: 'org.graalvm.polyglot', name: 'polyglot', version: '25.0.1'
implementation group: 'org.graalvm.js', name: 'js', version: '25.0.1'
implementation files('libs/JMongosh-0.9.1.jar')
implementation group: 'com.nimbusds', name: 'oauth2-oidc-sdk', version: '11.23.1'
implementation group: 'org.graalvm.polyglot', name: 'polyglot', version: '25.0.1'
testImplementation group: 'junit', name: 'junit', version: '4.13.1'
testImplementation group: 'commons-io', name: 'commons-io', version: '2.7'
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/dbschema/MongoJdbcDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public Connection connect(String url, Properties info) throws SQLException {
synchronized (this) {
ShellHolder shellHolder = this.shellHolder;
this.shellHolder = createShellHolder();

return new MongoConnection(url, info, username, password, fetchDocumentsForMeta, shellHolder);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

public class DriverPropertyInfoHelper {
public static final String AUTH_MECHANISM = "authMechanism";
public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256"};
public static final String[] AUTH_MECHANISM_CHOICES = new String[]{"GSSAPI", "MONGODB-AWS", "MONGODB-X509", "PLAIN", "SCRAM-SHA-1", "SCRAM-SHA-256", "MONGODB-OIDC"};
public static final String AUTH_SOURCE = "authSource";
public static final String AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN";
public static final String SERVICE_NAME = "SERVICE_NAME";
Expand Down
19 changes: 17 additions & 2 deletions src/main/java/com/dbschema/mongo/MongoClientWrapper.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.dbschema.mongo;

import com.dbschema.mongo.oidc.OidcCallback;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCredential;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoDatabase;
Expand Down Expand Up @@ -34,8 +36,11 @@ public MongoClientWrapper(@NotNull String uri, @NotNull Properties prop, @Nullab
automaticEncoding = Boolean.parseBoolean(prop.getProperty(ENCODE_CREDENTIALS));
}

uri = insertCredentials(uri, username, password, automaticEncoding);
uri = insertAuthMechanism(uri, prop.getProperty(AUTH_MECHANISM));
String authMechanism = prop.getProperty(AUTH_MECHANISM);
if (!"MONGODB-OIDC".equals(authMechanism)) {
uri = insertCredentials(uri, username, password, automaticEncoding);
}
uri = insertAuthMechanism(uri, authMechanism);
uri = insertAuthSource(uri, prop.getProperty(AUTH_SOURCE));
uri = insertAuthProperty(uri, AWS_SESSION_TOKEN, prop.getProperty(AWS_SESSION_TOKEN));
uri = insertAuthProperty(uri, SERVICE_NAME, prop.getProperty(SERVICE_NAME));
Expand All @@ -55,6 +60,15 @@ else if (canonicalizeHostName != null) {
MongoClientSettings.Builder builder = MongoClientSettings.builder()
.applyConnectionString(connectionString)
.applyToConnectionPoolSettings(b -> b.maxSize(maxPoolSize));

if ("MONGODB-OIDC".equals(authMechanism)) {
MongoCredential credential =
MongoCredential.createOidcCredential(null)
.withMechanismProperty(
MongoCredential.OIDC_HUMAN_CALLBACK_KEY, new OidcCallback());
builder.credential(credential);
}

String application = prop.getProperty(APPLICATION_NAME);
if (!isNullOrEmpty(application)) {
builder.applicationName(application);
Expand Down Expand Up @@ -99,6 +113,7 @@ else if (canonicalizeHostName != null) {
int timeout = Integer.parseInt(prop.getProperty(CONNECT_TIMEOUT, CONNECT_TIMEOUT_DEFAULT));
builder.applyToSocketSettings(b -> b.connectTimeout(timeout, TimeUnit.MILLISECONDS));
}

this.mongoClient = MongoClients.create(builder.build());
}
catch (Exception e) {
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/dbschema/mongo/MongoConnection.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.dbschema.mongo;

import com.dbschema.mongo.mongosh.MongoshScriptEngine;
import com.dbschema.mongo.mongosh.PrecalculatingShellHolder;
import com.dbschema.mongo.mongosh.ShellHolder;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand Down
292 changes: 292 additions & 0 deletions src/main/java/com/dbschema/mongo/oidc/OidcAuthFlow.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
package com.dbschema.mongo.oidc;

import com.mongodb.MongoCredential.IdpInfo;
import com.mongodb.MongoCredential.OidcCallbackContext;
import com.mongodb.MongoCredential.OidcCallbackResult;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.RefreshTokenGrant;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.oauth2.sdk.token.RefreshToken;
import com.nimbusds.oauth2.sdk.token.Tokens;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;

import java.awt.Desktop;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.RefreshFailedException;

public class OidcAuthFlow {

private static final Logger logger = Logger.getLogger(OidcAuthFlow.class.getName());
private static final String OFFLINE_ACCESS = "offline_access";
private static final String OPENID = "openid";

public Scope buildScopes(String clientID, IdpInfo idpServerInfo, OIDCProviderMetadata providerMetadata) {
Set<String> scopes = new HashSet<>();
Scope supportedScopes = providerMetadata.getScopes();

scopes.add(OPENID);
scopes.add(OFFLINE_ACCESS);

List<String> requestedScopes = idpServerInfo.getRequestScopes();
if (requestedScopes != null) {
String clientIDDefault = clientID + "/.default";
if (requestedScopes.contains(clientIDDefault)) {
scopes.add(clientIDDefault);
}
if (supportedScopes != null) {
for (String scope : requestedScopes) {
if (supportedScopes.contains(scope)) {
scopes.add(scope);
}
else {
logger.warning(String.format("Scope '%s' is not supported", scope));
}
}
}
}

Scope finalScopes = new Scope();
for (String scope : scopes) {
finalScopes.add(new Scope.Value(scope));
}
return finalScopes;
}

public OidcCallbackResult doAuthCodeFlow(OidcCallbackContext callbackContext)
throws OidcTimeoutException {
IdpInfo idpServerInfo = callbackContext.getIdpInfo();
String clientID = idpServerInfo.getClientId();
String issuerURI = idpServerInfo.getIssuer();

if (!isValid(idpServerInfo, clientID, issuerURI)) {
throw new IllegalStateException("OIDC configuration is incomplete: missing IdpInfo, clientID, or issuerURI");
}

Server server = new Server();
try {
OIDCProviderMetadata providerMetadata =
OIDCProviderMetadata.resolve(new Issuer(issuerURI));
URI authorizationEndpoint = providerMetadata.getAuthorizationEndpointURI();
URI tokenEndpoint = providerMetadata.getTokenEndpointURI();
Scope requestedScopes = buildScopes(clientID, idpServerInfo, providerMetadata);

server.start();

URI redirectURI = new URI("http://localhost:" + server.getPort() + "/redirect");
State state = new State();
CodeVerifier codeVerifier = new CodeVerifier();

AuthorizationRequest request =
new AuthorizationRequest.Builder(
new ResponseType(ResponseType.Value.CODE),
new ClientID(clientID))
.scope(requestedScopes)
.redirectionURI(redirectURI)
.state(state)
.codeChallenge(codeVerifier, CodeChallengeMethod.S256)
.endpointURI(authorizationEndpoint)
.build();

URI authorizationURI = request.toURI();
if (authorizationURI == null) {
throw new IllegalStateException("Authorization request URI is null");
}

try {
openURL(authorizationURI);
}
catch (Exception e) {
throw new IllegalStateException("Failed to open the browser: " + e.getMessage(), e);
}

OidcResponse response = server.getOidcResponse(callbackContext.getTimeout());
if (response == null || !state.getValue().equals(response.getState())) {
throw new IllegalStateException("OIDC response is null or returned an invalid state");
}

AuthorizationCode code = new AuthorizationCode(response.getCode());
AuthorizationCodeGrant codeGrant =
new AuthorizationCodeGrant(code, redirectURI, codeVerifier);
TokenRequest tokenRequest =
new TokenRequest(tokenEndpoint, new ClientID(clientID), codeGrant);

HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send();
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse);
if (!tokenResponse.indicatesSuccess()) {
throw new IllegalStateException(String.format("Token request failed: %s", httpResponse.getBody()));
}

return buildCallbackResult((OIDCTokenResponse) tokenResponse, issuerURI, clientID);
}
catch (OidcTimeoutException e) {
throw e;
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("OIDC authentication interrupted", e);
}
catch (Exception e) {
throw new RuntimeException("Error during OIDC authentication: " + e.getMessage(), e);
}
finally {
server.stop();
}
}

public OidcCallbackResult doRefresh(OidcCallbackContext callbackContext, String refreshTokenValue)
throws RefreshFailedException {
IdpInfo idpServerInfo = callbackContext.getIdpInfo();
String clientID = idpServerInfo.getClientId();
String issuerURI = idpServerInfo.getIssuer();

if (!isValid(idpServerInfo, clientID, issuerURI)) {
return null;
}
try {
OIDCProviderMetadata providerMetadata =
OIDCProviderMetadata.resolve(new Issuer(issuerURI));
URI tokenEndpoint = providerMetadata.getTokenEndpointURI();

if (refreshTokenValue == null) {
throw new IllegalArgumentException("Refresh token is required");
}

RefreshTokenGrant refreshTokenGrant =
new RefreshTokenGrant(new RefreshToken(refreshTokenValue));
TokenRequest tokenRequest =
new TokenRequest(tokenEndpoint, new ClientID(clientID), refreshTokenGrant);
HTTPResponse httpResponse = tokenRequest.toHTTPRequest().send();

try {
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(httpResponse);
if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
String errorCode = errorResponse.getErrorObject() != null
? errorResponse.getErrorObject().getCode() : null;
String errorDescription = errorResponse.getErrorObject() != null
? errorResponse.getErrorObject().getDescription() : null;
throw new RefreshFailedException(
"Token refresh failed: code=" + errorCode + ", description=" + errorDescription);
}
return buildCallbackResult((OIDCTokenResponse) tokenResponse, issuerURI, clientID);
}
catch (ParseException e) {
throw new RefreshFailedException(
"Failed to parse server response: " + e.getMessage()
+ " [response=" + httpResponse.getBody() + "]");
}
}
catch (RefreshFailedException e) {
throw e;
}
catch (Exception e) {
logger.log(Level.SEVERE, "OpenID Connect: Error during token refresh. " + e.getMessage());
throw new RefreshFailedException("Token refresh failed: " + e.getMessage());
}
}

private boolean isValid(IdpInfo idpInfo, String clientID, String issuerURI) {
return idpInfo != null && clientID != null && !clientID.isEmpty() && issuerURI != null;
}

private OidcCallbackResult buildCallbackResult(
OIDCTokenResponse tokenResponse, String issuerURI, String clientID) {
Tokens tokens = tokenResponse.getOIDCTokens();
String accessToken = tokens.getAccessToken().getValue();
String refreshToken =
tokens.getRefreshToken() != null ? tokens.getRefreshToken().getValue() : null;
Duration expiresIn = Duration.ofSeconds(tokens.getAccessToken().getLifetime());

OidcCallbackResult result = new OidcCallbackResult(accessToken, expiresIn, refreshToken);
OidcTokenCache.put(issuerURI, clientID, result, expiresIn);
return result;
}

private static final int MAX_URI_LENGTH = 2048;
private static final Set<String> ALLOWED_SCHEMES = Set.of("https", "http");

/**
* Opens the specified URI in the default web browser.
*
* Tries {@link Desktop#browse(URI)} first. When running inside DataGrip's
* driver JVM the custom AWT toolkit's {@code isSupported()} probe calls
* {@code browse(null)}, which crashes on the IDE side. In that case we
* fall back to platform commands matching IntelliJ's own
* {@code BrowserLauncherAppless}.
*/
private void openURL(URI uri) throws IOException {
validateURI(uri);

try {
Desktop.getDesktop().browse(uri);
return;
}
catch (Exception e) {
logger.log(Level.WARNING, "Desktop.browse() failed, falling back to platform command", e);
}

String osName = System.getProperty("os.name", "").toLowerCase();
ProcessBuilder pb;
if (osName.contains("mac")) {
pb = new ProcessBuilder("open", uri.toString());
}
else if (osName.contains("windows")) {
pb = new ProcessBuilder("cmd.exe", "/c", "start", "\"\"", uri.toString());
}
else if (osName.contains("linux")) {
pb = new ProcessBuilder("xdg-open", uri.toString());
}
else {
throw new UnsupportedOperationException("Cannot open browser on " + osName);
}
pb.redirectErrorStream(true);
pb.start();
}

private static void validateURI(URI uri) {
String scheme = uri.getScheme();
if (scheme == null || !ALLOWED_SCHEMES.contains(scheme.toLowerCase())) {
throw new IllegalArgumentException("Refusing to open URI with scheme: " + scheme);
}

String host = uri.getHost();
if (host == null || host.isEmpty()) {
throw new IllegalArgumentException("URI must have a host");
}

if (uri.toString().length() > MAX_URI_LENGTH) {
throw new IllegalArgumentException("URI exceeds maximum length of " + MAX_URI_LENGTH);
}

if (uri.getUserInfo() != null) {
throw new IllegalArgumentException("URI must not contain user info");
}

String uriString = uri.toString();
if (uriString.contains("..") || uriString.contains("\n") || uriString.contains("\r")) {
throw new IllegalArgumentException("URI contains invalid characters");
}
}
}
Loading