Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions modules/database/aws_iam_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,13 @@
// Handle different DSN formats
if strings.Contains(dsn, "://") {
// URL-style DSN (e.g., postgres://user:password@host:port/database)
u, err := url.Parse(dsn)
// Handle potential special characters in password by preprocessing
preprocessedDSN, err := preprocessDSNForParsing(dsn)
if err != nil {
return "", fmt.Errorf("failed to preprocess DSN: %w", err)
}

Check warning on line 184 in modules/database/aws_iam_auth.go

View check run for this annotation

Codecov / codecov/patch

modules/database/aws_iam_auth.go#L183-L184

Added lines #L183 - L184 were not covered by tests

u, err := url.Parse(preprocessedDSN)
if err != nil {
return "", fmt.Errorf("failed to parse DSN URL: %w", err)
}
Expand All @@ -203,11 +209,71 @@
return "", ErrExtractEndpointFailed
}

// preprocessDSNForParsing handles special characters in passwords by URL-encoding them
func preprocessDSNForParsing(dsn string) (string, error) {
// Find the pattern: ://username:password@host
protocolEnd := strings.Index(dsn, "://")
if protocolEnd == -1 {
return dsn, nil // Not a URL-style DSN
}

Check warning on line 218 in modules/database/aws_iam_auth.go

View check run for this annotation

Codecov / codecov/patch

modules/database/aws_iam_auth.go#L217-L218

Added lines #L217 - L218 were not covered by tests

// Find the start of credentials (after ://)
credentialsStart := protocolEnd + 3

// Find the end of credentials (before @host)
// We need to find the last @ that separates credentials from host
// Look for the pattern @host:port or @host/path
remainingDSN := dsn[credentialsStart:]

// Find all @ characters
atIndices := []int{}
for i := 0; i < len(remainingDSN); i++ {
if remainingDSN[i] == '@' {
atIndices = append(atIndices, i)
}
}

if len(atIndices) == 0 {
return dsn, nil // No credentials
}

// Use the last @ as the separator between credentials and host
atIndex := atIndices[len(atIndices)-1]

// Extract the credentials part
credentialsEnd := credentialsStart + atIndex
credentials := dsn[credentialsStart:credentialsEnd]

// Find the colon that separates username from password
colonIndex := strings.Index(credentials, ":")
if colonIndex == -1 {
return dsn, nil // No password
}

Check warning on line 251 in modules/database/aws_iam_auth.go

View check run for this annotation

Codecov / codecov/patch

modules/database/aws_iam_auth.go#L250-L251

Added lines #L250 - L251 were not covered by tests

// Extract username and password
username := credentials[:colonIndex]
password := credentials[colonIndex+1:]

// URL-encode the password
encodedPassword := url.QueryEscape(password)

// Reconstruct the DSN with encoded password
encodedDSN := dsn[:credentialsStart] + username + ":" + encodedPassword + dsn[credentialsEnd:]

return encodedDSN, nil
}

// replaceDSNPassword replaces the password in a DSN with the provided token
func replaceDSNPassword(dsn, token string) (string, error) {
if strings.Contains(dsn, "://") {
// URL-style DSN
u, err := url.Parse(dsn)
// Handle potential special characters in password by preprocessing
preprocessedDSN, err := preprocessDSNForParsing(dsn)
if err != nil {
return "", fmt.Errorf("failed to preprocess DSN: %w", err)
}

Check warning on line 274 in modules/database/aws_iam_auth.go

View check run for this annotation

Codecov / codecov/patch

modules/database/aws_iam_auth.go#L273-L274

Added lines #L273 - L274 were not covered by tests

u, err := url.Parse(preprocessedDSN)
if err != nil {
return "", fmt.Errorf("failed to parse DSN URL: %w", err)
}
Expand Down
30 changes: 30 additions & 0 deletions modules/database/aws_iam_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,24 @@ func TestExtractEndpointFromDSN(t *testing.T) {
expected: "mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432",
wantErr: false,
},
{
name: "postgres URL style with special characters in password",
dsn: "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend",
expected: "some-dev-backend.cluster.us-east-1.rds.amazonaws.com",
wantErr: false,
},
{
name: "postgres URL style with URL-encoded special characters in password",
dsn: "postgresql://someuser:8jKwouNHdI%21u6a%3Fkx%28UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend",
expected: "some-dev-backend.cluster.us-east-1.rds.amazonaws.com",
wantErr: false,
},
{
name: "postgres URL style with complex special characters in password",
dsn: "postgres://user:p@ssw0rd!#$^&*()_+-=[]{}|;':\",./<>@host.example.com:5432/db",
expected: "host.example.com:5432",
wantErr: false,
},
{
name: "invalid DSN",
dsn: "invalid-dsn",
Expand Down Expand Up @@ -167,6 +185,18 @@ func TestReplaceDSNPassword(t *testing.T) {
expected: "host=localhost port=5432 user=postgres dbname=mydb password=test-iam-token",
wantErr: false,
},
{
name: "postgres URL style with special characters in password",
dsn: "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend",
expected: "postgresql://someuser:test-iam-token@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend",
wantErr: false,
},
{
name: "postgres URL style with complex special characters in password",
dsn: "postgres://user:p@ssw0rd!#$^&*()_+-=[]{}|;':\",./<>@host.example.com:5432/db",
expected: "postgres://user:test-iam-token@host.example.com:5432/db",
wantErr: false,
},
{
name: "URL style without user info",
dsn: "postgres://host:5432/mydb",
Expand Down
125 changes: 125 additions & 0 deletions modules/database/dsn_special_chars_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package database

import (
"testing"

"github.com/stretchr/testify/require"
)

// TestSpecialCharacterPasswordDSNParsing tests the specific issue from the GitHub issue #19
func TestSpecialCharacterPasswordDSNParsing(t *testing.T) {
// This is the exact DSN from the GitHub issue
issueExampleDSN := "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend"

// Test that endpoint extraction works
endpoint, err := extractEndpointFromDSN(issueExampleDSN)
require.NoError(t, err)
require.Equal(t, "some-dev-backend.cluster.us-east-1.rds.amazonaws.com", endpoint)

// Test that password replacement works
token := "test-iam-token"
newDSN, err := replaceDSNPassword(issueExampleDSN, token)
require.NoError(t, err)
require.Contains(t, newDSN, "postgresql://someuser:test-iam-token@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend")

// Test that we can create a database service with this DSN (without actually connecting)
config := ConnectionConfig{
Driver: "postgres",
DSN: issueExampleDSN,
}

service, err := NewDatabaseService(config)
require.NoError(t, err)
require.NotNil(t, service)

// Clean up
err = service.Close()
require.NoError(t, err)
}

// TestSpecialCharacterPasswordDSNParsingWithAWSIAM tests the issue with AWS IAM auth
func TestSpecialCharacterPasswordDSNParsingWithAWSIAM(t *testing.T) {
// This is the exact DSN from the GitHub issue
issueExampleDSN := "postgresql://someuser:8jKwouNHdI!u6a?kx(UuQ-Bgm34P@some-dev-backend.cluster.us-east-1.rds.amazonaws.com/some_backend"

// Test that we can create a database service with AWS IAM auth enabled
config := ConnectionConfig{
Driver: "postgres",
DSN: issueExampleDSN,
AWSIAMAuth: &AWSIAMAuthConfig{
Enabled: true,
Region: "us-east-1",
DBUser: "someuser",
TokenRefreshInterval: 300,
},
}

// Skip this test if AWS credentials are not available
service, err := NewDatabaseService(config)
if err != nil {
// If AWS config loading fails, skip this test
if err.Error() == "failed to create AWS IAM token provider: failed to load AWS config: no EC2 IMDS role found, operation error ec2imds: GetMetadata, canceled, context canceled" {
t.Skip("AWS credentials not available, skipping test")
}
t.Fatalf("Failed to create service: %v", err)
}
require.NotNil(t, service)

// Clean up
err = service.Close()
require.NoError(t, err)
}

// TestEdgeCaseSpecialCharacterPasswords tests various edge cases
func TestEdgeCaseSpecialCharacterPasswords(t *testing.T) {
testCases := []struct {
name string
dsn string
expectedHost string
}{
{
name: "password with @ symbol",
dsn: "postgres://user:pass@word@host.com:5432/db",
expectedHost: "host.com:5432",
},
{
name: "password with multiple @ symbols",
dsn: "postgres://user:p@ss@w@rd@host.com:5432/db",
expectedHost: "host.com:5432",
},
{
name: "password with query-like characters",
dsn: "postgres://user:pass?key=value&other=test@host.com:5432/db",
expectedHost: "host.com:5432",
},
{
name: "password with URL-like structure",
dsn: "postgres://user:http://example.com/path?query=value@host.com:5432/db",
expectedHost: "host.com:5432",
},
{
name: "password with colon",
dsn: "postgres://user:pass:word@host.com:5432/db",
expectedHost: "host.com:5432",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
endpoint, err := extractEndpointFromDSN(tc.dsn)
require.NoError(t, err)
require.Equal(t, tc.expectedHost, endpoint)

// Test password replacement
token := "test-token"
newDSN, err := replaceDSNPassword(tc.dsn, token)
require.NoError(t, err)
require.Contains(t, newDSN, token)

// Verify we can parse the new DSN
newEndpoint, err := extractEndpointFromDSN(newDSN)
require.NoError(t, err)
require.Equal(t, tc.expectedHost, newEndpoint)
})
}
}