Skip to content

Commit da15554

Browse files
remove searchpath
1 parent edc2f00 commit da15554

2 files changed

Lines changed: 30 additions & 100 deletions

File tree

cmd/rds-iam-psql/README.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
A CLI that launches an interactive `psql` session from a required RDS IAM URL:
44
- positional `postgres+rds-iam://...` DSN
5-
- optional `-search-path` flag
65
- optional `-debug-aws` flag
76

87
## Why?
@@ -33,7 +32,7 @@ go build
3332
## Usage
3433

3534
```bash
36-
rds-iam-psql [-search-path "schema,public"] [-debug-aws] '<postgres+rds-iam-url>'
35+
rds-iam-psql [-debug-aws] '<postgres+rds-iam-url>'
3736
```
3837

3938
- Flags must come before the DSN (standard Go flag parsing behavior).
@@ -43,7 +42,6 @@ rds-iam-psql [-search-path "schema,public"] [-debug-aws] '<postgres+rds-iam-url>
4342

4443
| Flag | Default | Description |
4544
|------|---------|-------------|
46-
| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) |
4745
| `-debug-aws` | `false` | Print STS caller identity before connecting |
4846

4947
## Examples
@@ -60,14 +58,6 @@ IAM URL with cross-account role assumption:
6058
rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql'
6159
```
6260

63-
With search path:
64-
65-
```bash
66-
rds-iam-psql \
67-
-search-path "app_schema,public" \
68-
'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp'
69-
```
70-
7161
With AWS identity debugging:
7262

7363
```bash
@@ -80,13 +70,23 @@ Without explicit database name (defaults to username):
8070
rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432'
8171
```
8272

73+
## Changing Search Path In psql
74+
75+
If you need to change the schema search path, do it from the interactive `psql` session after connecting:
76+
77+
```sql
78+
SHOW search_path;
79+
SET search_path TO app_schema, public;
80+
```
81+
82+
This applies to the current session. If you need a persistent default, configure it in Postgres (for example with `ALTER ROLE ... SET search_path ...`).
83+
8384
## How It Works
8485

8586
1. Parses and validates the positional IAM URL.
8687
2. Builds a `pgutils` connection string provider from the IAM URL.
87-
3. If `-search-path` is set, adds libpq `options=-csearch_path=...` to the connection URI before launching `psql`.
88-
4. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN.
89-
5. Resolves an IAM tokenized DSN from the provider and launches `psql` with:
88+
3. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN.
89+
4. Resolves an IAM tokenized DSN from the provider and launches `psql` with:
9090
- `PGPASSWORD` set from the generated token
9191

9292
## Setting Up IAM Auth on RDS

cmd/rds-iam-psql/main.go

Lines changed: 16 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
)
2323

2424
const usageTemplate = `Usage:
25-
%[2]s [-search-path "login,public"] [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
25+
%[2]s [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
2626
2727
Notes:
2828
Flags must come before the DSN (standard Go flag parsing).
@@ -33,12 +33,11 @@ Flags:
3333
3434
Examples:
3535
%[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
36-
%[2]s -search-path "login,public" 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432'
3736
%[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
3837
`
3938

4039
func main() {
41-
rawURL, searchPath, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0])
40+
rawURL, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0])
4241
if err != nil {
4342
if errors.Is(err, flag.ErrHelp) {
4443
printUsage(os.Stdout, os.Args[0])
@@ -76,41 +75,14 @@ func main() {
7675
log.Fatalf("failed to get connection string from provider: %v", err)
7776
}
7877

79-
parsedURL, err := url.Parse(dsnWithToken)
80-
if err != nil {
81-
log.Fatalf("failed to parse connection string from provider: %v", err)
82-
}
83-
84-
if err := addSearchPathToPSQLURL(parsedURL, searchPath); err != nil {
85-
fmt.Fprintf(os.Stderr, "%v\n", err)
86-
os.Exit(2)
87-
}
88-
89-
password := ""
90-
if parsedURL.User != nil {
91-
var ok bool
92-
password, ok = parsedURL.User.Password()
93-
if ok {
94-
parsedURL.User = url.User(parsedURL.User.Username())
95-
}
96-
}
97-
98-
// Pass DSN to psql without password in argv, and provide password via env.
99-
cmd := exec.Command("psql", parsedURL.String())
78+
cmd := exec.Command("psql", dsnWithToken)
10079

10180
cmd.Stdin = os.Stdin
10281
cmd.Stdout = os.Stdout
10382
cmd.Stderr = os.Stderr
10483

105-
env := os.Environ()
106-
if password != "" {
107-
env = append(env, "PGPASSWORD="+password)
108-
}
109-
110-
cmd.Env = env
111-
112-
// Keep psql in the foreground process group. Swallow SIGINT in wrapper so
113-
// psql handles Ctrl-C directly.
84+
// Ignore SIGINT in the wrapper so interactive Ctrl-C can be handled by psql.
85+
// Forward SIGTERM to the child process.
11486
sigCh := make(chan os.Signal, 1)
11587
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
11688
defer signal.Stop(sigCh)
@@ -145,17 +117,16 @@ func main() {
145117
}
146118
}
147119

148-
func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, searchPathFlag *string, debugAWSFlag *bool) {
120+
func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, debugAWSFlag *bool) {
149121
fs = flag.NewFlagSet(bin, flag.ContinueOnError)
150122
fs.SetOutput(output)
151123

152124
return fs,
153-
fs.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')"),
154125
fs.Bool("debug-aws", false, "Print AWS caller identity before connecting")
155126
}
156127

157128
func printUsage(output io.Writer, bin string) {
158-
fs, _, _ := newFlagSet(bin, io.Discard)
129+
fs, _ := newFlagSet(bin, io.Discard)
159130

160131
var defaults bytes.Buffer
161132
fs.SetOutput(&defaults)
@@ -164,63 +135,19 @@ func printUsage(output io.Writer, bin string) {
164135
fmt.Fprintf(output, usageTemplate, strings.TrimRight(defaults.String(), "\n"), bin)
165136
}
166137

167-
func parseCLIArgs(args []string, bin string) (rawURL string, searchPath string, debugAWS bool, err error) {
168-
fs, searchPathFlag, debugAWSFlag := newFlagSet(bin, io.Discard)
138+
func parseCLIArgs(args []string, bin string) (rawURL string, debugAWS bool, err error) {
139+
fs, debugAWSFlag := newFlagSet(bin, io.Discard)
169140

170141
if err := fs.Parse(args); err != nil {
171-
return "", "", false, err
142+
return "", false, err
172143
}
173144

174145
positionals := fs.Args()
175146
if len(positionals) != 1 {
176-
return "", "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals))
147+
return "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals))
177148
}
178149

179-
return positionals[0], *searchPathFlag, *debugAWSFlag, nil
180-
}
181-
182-
func addSearchPathToPSQLURL(u *url.URL, searchPath string) error {
183-
normalized, err := normalizeSearchPath(searchPath)
184-
if err != nil {
185-
return err
186-
}
187-
if normalized == "" {
188-
return nil
189-
}
190-
191-
query := u.Query()
192-
add := "-csearch_path=" + normalized
193-
194-
existing := strings.TrimSpace(query.Get("options"))
195-
if existing == "" {
196-
query.Set("options", add)
197-
} else {
198-
query.Set("options", existing+" "+add)
199-
}
200-
201-
u.RawQuery = query.Encode()
202-
return nil
203-
}
204-
205-
func normalizeSearchPath(searchPath string) (string, error) {
206-
if strings.TrimSpace(searchPath) == "" {
207-
return "", nil
208-
}
209-
210-
parts := strings.Split(searchPath, ",")
211-
cleaned := make([]string, 0, len(parts))
212-
for _, p := range parts {
213-
p = strings.TrimSpace(p)
214-
if p != "" {
215-
cleaned = append(cleaned, p)
216-
}
217-
}
218-
219-
if len(cleaned) == 0 {
220-
return "", fmt.Errorf("search path cannot be empty")
221-
}
222-
223-
return strings.Join(cleaned, ","), nil
150+
return positionals[0], *debugAWSFlag, nil
224151
}
225152

226153
func validateRDSIAMURL(rawURL string) error {
@@ -234,6 +161,9 @@ func validateRDSIAMURL(rawURL string) error {
234161
if parsedURL.User == nil || strings.TrimSpace(parsedURL.User.Username()) == "" {
235162
return fmt.Errorf("connection URL must include a database username")
236163
}
164+
if _, ok := parsedURL.User.Password(); ok {
165+
return fmt.Errorf("connection URL must not include a password for postgres+rds-iam")
166+
}
237167
if strings.TrimSpace(parsedURL.Host) == "" {
238168
return fmt.Errorf("connection URL must include a database host")
239169
}
@@ -249,6 +179,6 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error {
249179
return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err)
250180
}
251181

252-
fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn))
182+
fmt.Fprintf(os.Stderr, "Caller ARN: %s\n", aws.ToString(out.Arn))
253183
return nil
254184
}

0 commit comments

Comments
 (0)