@@ -22,7 +22,7 @@ import (
2222)
2323
2424const usageTemplate = `Usage:
25- %[2]s [-search-path "login,public"] [- debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
25+ %[2]s [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
2626
2727Notes:
2828 Flags must come before the DSN (standard Go flag parsing).
@@ -33,12 +33,11 @@ Flags:
3333
3434Examples:
3535 %[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
36- %[2]s -search-path "login,public" 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432'
3736 %[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
3837`
3938
4039func main () {
41- rawURL , searchPath , debugAWS , err := parseCLIArgs (os .Args [1 :], os .Args [0 ])
40+ rawURL , debugAWS , err := parseCLIArgs (os .Args [1 :], os .Args [0 ])
4241 if err != nil {
4342 if errors .Is (err , flag .ErrHelp ) {
4443 printUsage (os .Stdout , os .Args [0 ])
@@ -76,41 +75,14 @@ func main() {
7675 log .Fatalf ("failed to get connection string from provider: %v" , err )
7776 }
7877
79- parsedURL , err := url .Parse (dsnWithToken )
80- if err != nil {
81- log .Fatalf ("failed to parse connection string from provider: %v" , err )
82- }
83-
84- if err := addSearchPathToPSQLURL (parsedURL , searchPath ); err != nil {
85- fmt .Fprintf (os .Stderr , "%v\n " , err )
86- os .Exit (2 )
87- }
88-
89- password := ""
90- if parsedURL .User != nil {
91- var ok bool
92- password , ok = parsedURL .User .Password ()
93- if ok {
94- parsedURL .User = url .User (parsedURL .User .Username ())
95- }
96- }
97-
98- // Pass DSN to psql without password in argv, and provide password via env.
99- cmd := exec .Command ("psql" , parsedURL .String ())
78+ cmd := exec .Command ("psql" , dsnWithToken )
10079
10180 cmd .Stdin = os .Stdin
10281 cmd .Stdout = os .Stdout
10382 cmd .Stderr = os .Stderr
10483
105- env := os .Environ ()
106- if password != "" {
107- env = append (env , "PGPASSWORD=" + password )
108- }
109-
110- cmd .Env = env
111-
112- // Keep psql in the foreground process group. Swallow SIGINT in wrapper so
113- // psql handles Ctrl-C directly.
84+ // Ignore SIGINT in the wrapper so interactive Ctrl-C can be handled by psql.
85+ // Forward SIGTERM to the child process.
11486 sigCh := make (chan os.Signal , 1 )
11587 signal .Notify (sigCh , os .Interrupt , syscall .SIGTERM )
11688 defer signal .Stop (sigCh )
@@ -145,17 +117,16 @@ func main() {
145117 }
146118}
147119
148- func newFlagSet (bin string , output io.Writer ) (fs * flag.FlagSet , searchPathFlag * string , debugAWSFlag * bool ) {
120+ func newFlagSet (bin string , output io.Writer ) (fs * flag.FlagSet , debugAWSFlag * bool ) {
149121 fs = flag .NewFlagSet (bin , flag .ContinueOnError )
150122 fs .SetOutput (output )
151123
152124 return fs ,
153- fs .String ("search-path" , "" , "Optional PostgreSQL search_path to set (e.g. 'myschema,public')" ),
154125 fs .Bool ("debug-aws" , false , "Print AWS caller identity before connecting" )
155126}
156127
157128func printUsage (output io.Writer , bin string ) {
158- fs , _ , _ := newFlagSet (bin , io .Discard )
129+ fs , _ := newFlagSet (bin , io .Discard )
159130
160131 var defaults bytes.Buffer
161132 fs .SetOutput (& defaults )
@@ -164,63 +135,19 @@ func printUsage(output io.Writer, bin string) {
164135 fmt .Fprintf (output , usageTemplate , strings .TrimRight (defaults .String (), "\n " ), bin )
165136}
166137
167- func parseCLIArgs (args []string , bin string ) (rawURL string , searchPath string , debugAWS bool , err error ) {
168- fs , searchPathFlag , debugAWSFlag := newFlagSet (bin , io .Discard )
138+ func parseCLIArgs (args []string , bin string ) (rawURL string , debugAWS bool , err error ) {
139+ fs , debugAWSFlag := newFlagSet (bin , io .Discard )
169140
170141 if err := fs .Parse (args ); err != nil {
171- return "" , "" , false , err
142+ return "" , false , err
172143 }
173144
174145 positionals := fs .Args ()
175146 if len (positionals ) != 1 {
176- return "" , "" , false , fmt .Errorf ("expected exactly one positional RDS IAM connection URL argument, got %d" , len (positionals ))
147+ return "" , false , fmt .Errorf ("expected exactly one positional RDS IAM connection URL argument, got %d" , len (positionals ))
177148 }
178149
179- return positionals [0 ], * searchPathFlag , * debugAWSFlag , nil
180- }
181-
182- func addSearchPathToPSQLURL (u * url.URL , searchPath string ) error {
183- normalized , err := normalizeSearchPath (searchPath )
184- if err != nil {
185- return err
186- }
187- if normalized == "" {
188- return nil
189- }
190-
191- query := u .Query ()
192- add := "-csearch_path=" + normalized
193-
194- existing := strings .TrimSpace (query .Get ("options" ))
195- if existing == "" {
196- query .Set ("options" , add )
197- } else {
198- query .Set ("options" , existing + " " + add )
199- }
200-
201- u .RawQuery = query .Encode ()
202- return nil
203- }
204-
205- func normalizeSearchPath (searchPath string ) (string , error ) {
206- if strings .TrimSpace (searchPath ) == "" {
207- return "" , nil
208- }
209-
210- parts := strings .Split (searchPath , "," )
211- cleaned := make ([]string , 0 , len (parts ))
212- for _ , p := range parts {
213- p = strings .TrimSpace (p )
214- if p != "" {
215- cleaned = append (cleaned , p )
216- }
217- }
218-
219- if len (cleaned ) == 0 {
220- return "" , fmt .Errorf ("search path cannot be empty" )
221- }
222-
223- return strings .Join (cleaned , "," ), nil
150+ return positionals [0 ], * debugAWSFlag , nil
224151}
225152
226153func validateRDSIAMURL (rawURL string ) error {
@@ -234,6 +161,9 @@ func validateRDSIAMURL(rawURL string) error {
234161 if parsedURL .User == nil || strings .TrimSpace (parsedURL .User .Username ()) == "" {
235162 return fmt .Errorf ("connection URL must include a database username" )
236163 }
164+ if _ , ok := parsedURL .User .Password (); ok {
165+ return fmt .Errorf ("connection URL must not include a password for postgres+rds-iam" )
166+ }
237167 if strings .TrimSpace (parsedURL .Host ) == "" {
238168 return fmt .Errorf ("connection URL must include a database host" )
239169 }
@@ -249,6 +179,6 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error {
249179 return fmt .Errorf ("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w" , err )
250180 }
251181
252- fmt .Printf ( "Caller ARN: %s\n " , aws .ToString (out .Arn ))
182+ fmt .Fprintf ( os . Stderr , "Caller ARN: %s\n " , aws .ToString (out .Arn ))
253183 return nil
254184}
0 commit comments