diff --git a/modules/database/aws_iam_auth.go b/modules/database/aws_iam_auth.go index 72c1561b..231ab079 100644 --- a/modules/database/aws_iam_auth.go +++ b/modules/database/aws_iam_auth.go @@ -177,7 +177,13 @@ func extractEndpointFromDSN(dsn string) (string, error) { // Handle different DSN formats if strings.Contains(dsn, "://") { // URL-style DSN (e.g., postgres://user:password@host:port/database) - u, err := url.Parse(dsn) + // Handle potential special characters in password by preprocessing + preprocessedDSN, err := preprocessDSNForParsing(dsn) + if err != nil { + return "", fmt.Errorf("failed to preprocess DSN: %w", err) + } + + u, err := url.Parse(preprocessedDSN) if err != nil { return "", fmt.Errorf("failed to parse DSN URL: %w", err) } @@ -203,11 +209,71 @@ func extractEndpointFromDSN(dsn string) (string, error) { return "", ErrExtractEndpointFailed } +// preprocessDSNForParsing handles special characters in passwords by URL-encoding them +func preprocessDSNForParsing(dsn string) (string, error) { + // Find the pattern: ://username:password@host + protocolEnd := strings.Index(dsn, "://") + if protocolEnd == -1 { + return dsn, nil // Not a URL-style DSN + } + + // Find the start of credentials (after ://) + credentialsStart := protocolEnd + 3 + + // Find the end of credentials (before @host) + // We need to find the last @ that separates credentials from host + // Look for the pattern @host:port or @host/path + remainingDSN := dsn[credentialsStart:] + + // Find all @ characters + atIndices := []int{} + for i := 0; i < len(remainingDSN); i++ { + if remainingDSN[i] == '@' { + atIndices = append(atIndices, i) + } + } + + if len(atIndices) == 0 { + return dsn, nil // No credentials + } + + // Use the last @ as the separator between credentials and host + atIndex := atIndices[len(atIndices)-1] + + // Extract the credentials part + credentialsEnd := credentialsStart + atIndex + credentials := dsn[credentialsStart:credentialsEnd] + + // Find the colon that separates username from password + colonIndex := strings.Index(credentials, ":") + if colonIndex == -1 { + return dsn, nil // No password + } + + // Extract username and password + username := credentials[:colonIndex] + password := credentials[colonIndex+1:] + + // URL-encode the password + encodedPassword := url.QueryEscape(password) + + // Reconstruct the DSN with encoded password + encodedDSN := dsn[:credentialsStart] + username + ":" + encodedPassword + dsn[credentialsEnd:] + + return encodedDSN, nil +} + // replaceDSNPassword replaces the password in a DSN with the provided token func replaceDSNPassword(dsn, token string) (string, error) { if strings.Contains(dsn, "://") { // URL-style DSN - u, err := url.Parse(dsn) + // Handle potential special characters in password by preprocessing + preprocessedDSN, err := preprocessDSNForParsing(dsn) + if err != nil { + return "", fmt.Errorf("failed to preprocess DSN: %w", err) + } + + u, err := url.Parse(preprocessedDSN) if err != nil { return "", fmt.Errorf("failed to parse DSN URL: %w", err) } diff --git a/modules/database/aws_iam_auth_test.go b/modules/database/aws_iam_auth_test.go index b511fbf9..e20b4a99 100644 --- a/modules/database/aws_iam_auth_test.go +++ b/modules/database/aws_iam_auth_test.go @@ -114,6 +114,24 @@ func TestExtractEndpointFromDSN(t *testing.T) { expected: "mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432", wantErr: false, }, + { + name: "postgres URL style with special characters in password", + dsn: "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend", + expected: "some-dev-backend.cluster.us-east-1.rds.amazonaws.com", + wantErr: false, + }, + { + name: "postgres URL style with URL-encoded special characters in password", + dsn: "postgresql://someuser:8jKwouNHdI%21u6a%3Fkx%28UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend", + expected: "some-dev-backend.cluster.us-east-1.rds.amazonaws.com", + wantErr: false, + }, + { + name: "postgres URL style with complex special characters in password", + dsn: "postgres://user:p@ssw0rd!#$^&*()_+-=[]{}|;':\",./<>@host.example.com:5432/db", + expected: "host.example.com:5432", + wantErr: false, + }, { name: "invalid DSN", dsn: "invalid-dsn", @@ -167,6 +185,18 @@ func TestReplaceDSNPassword(t *testing.T) { expected: "host=localhost port=5432 user=postgres dbname=mydb password=test-iam-token", wantErr: false, }, + { + name: "postgres URL style with special characters in password", + dsn: "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend", + expected: "postgresql://someuser:test-iam-token@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend", + wantErr: false, + }, + { + name: "postgres URL style with complex special characters in password", + dsn: "postgres://user:p@ssw0rd!#$^&*()_+-=[]{}|;':\",./<>@host.example.com:5432/db", + expected: "postgres://user:test-iam-token@host.example.com:5432/db", + wantErr: false, + }, { name: "URL style without user info", dsn: "postgres://host:5432/mydb", diff --git a/modules/database/dsn_special_chars_test.go b/modules/database/dsn_special_chars_test.go new file mode 100644 index 00000000..d6118622 --- /dev/null +++ b/modules/database/dsn_special_chars_test.go @@ -0,0 +1,125 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSpecialCharacterPasswordDSNParsing tests the specific issue from the GitHub issue #19 +func TestSpecialCharacterPasswordDSNParsing(t *testing.T) { + // This is the exact DSN from the GitHub issue + issueExampleDSN := "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend" + + // Test that endpoint extraction works + endpoint, err := extractEndpointFromDSN(issueExampleDSN) + require.NoError(t, err) + require.Equal(t, "some-dev-backend.cluster.us-east-1.rds.amazonaws.com", endpoint) + + // Test that password replacement works + token := "test-iam-token" + newDSN, err := replaceDSNPassword(issueExampleDSN, token) + require.NoError(t, err) + require.Contains(t, newDSN, "postgresql://someuser:test-iam-token@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend") + + // Test that we can create a database service with this DSN (without actually connecting) + config := ConnectionConfig{ + Driver: "postgres", + DSN: issueExampleDSN, + } + + service, err := NewDatabaseService(config) + require.NoError(t, err) + require.NotNil(t, service) + + // Clean up + err = service.Close() + require.NoError(t, err) +} + +// TestSpecialCharacterPasswordDSNParsingWithAWSIAM tests the issue with AWS IAM auth +func TestSpecialCharacterPasswordDSNParsingWithAWSIAM(t *testing.T) { + // This is the exact DSN from the GitHub issue + issueExampleDSN := "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend" + + // Test that we can create a database service with AWS IAM auth enabled + config := ConnectionConfig{ + Driver: "postgres", + DSN: issueExampleDSN, + AWSIAMAuth: &AWSIAMAuthConfig{ + Enabled: true, + Region: "us-east-1", + DBUser: "someuser", + TokenRefreshInterval: 300, + }, + } + + // Skip this test if AWS credentials are not available + service, err := NewDatabaseService(config) + if err != nil { + // If AWS config loading fails, skip this test + if err.Error() == "failed to create AWS IAM token provider: failed to load AWS config: no EC2 IMDS role found, operation error ec2imds: GetMetadata, canceled, context canceled" { + t.Skip("AWS credentials not available, skipping test") + } + t.Fatalf("Failed to create service: %v", err) + } + require.NotNil(t, service) + + // Clean up + err = service.Close() + require.NoError(t, err) +} + +// TestEdgeCaseSpecialCharacterPasswords tests various edge cases +func TestEdgeCaseSpecialCharacterPasswords(t *testing.T) { + testCases := []struct { + name string + dsn string + expectedHost string + }{ + { + name: "password with @ symbol", + dsn: "postgres://user:pass@word@host.com:5432/db", + expectedHost: "host.com:5432", + }, + { + name: "password with multiple @ symbols", + dsn: "postgres://user:p@ss@w@rd@host.com:5432/db", + expectedHost: "host.com:5432", + }, + { + name: "password with query-like characters", + dsn: "postgres://user:pass?key=value&other=test@host.com:5432/db", + expectedHost: "host.com:5432", + }, + { + name: "password with URL-like structure", + dsn: "postgres://user:http://example.com/path?query=value@host.com:5432/db", + expectedHost: "host.com:5432", + }, + { + name: "password with colon", + dsn: "postgres://user:pass:word@host.com:5432/db", + expectedHost: "host.com:5432", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + endpoint, err := extractEndpointFromDSN(tc.dsn) + require.NoError(t, err) + require.Equal(t, tc.expectedHost, endpoint) + + // Test password replacement + token := "test-token" + newDSN, err := replaceDSNPassword(tc.dsn, token) + require.NoError(t, err) + require.Contains(t, newDSN, token) + + // Verify we can parse the new DSN + newEndpoint, err := extractEndpointFromDSN(newDSN) + require.NoError(t, err) + require.Equal(t, tc.expectedHost, newEndpoint) + }) + } +}