-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcmd.go
More file actions
265 lines (236 loc) · 8.91 KB
/
cmd.go
File metadata and controls
265 lines (236 loc) · 8.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
package main
import (
"context"
"flag"
"fmt"
"io"
"os"
"os/signal"
"runtime/debug"
"strings"
"sync/atomic"
"syscall"
"time"
"google.golang.org/genai"
)
// Version and Githash are populated by make.
// TokenCount is the atomic token counter whose value cleanup prints when the
// -t flag is set. NOTE(review): the code that adds to it is not in this part
// of the file — presumably the generation path; confirm.
var (
Version string
Githash string
TokenCount atomic.Int32
)
// File extensions, file names and the prompt placeholder recognised by gen.
const (
SPExt = ".sprompt" // system prompt extension
PExt = ".prompt" // regular prompt extension
DigestKey = "{digest}" // key to replace with embedded content
DotGen = ".gen" // name of chat history file
DotGenRc = ".genrc" // name of preferences file
)
// Parameters holds gen flag values as well as Args and MCP sessions.
// Most fields are bound to a command-line flag in parseFlags; the flag name
// is given next to each field. Fields without a flag are filled in by
// parseFlags or by other parts of the program.
type Parameters struct {
Args []string // non-flag command-line arguments i.e. prompt (flag.Args())
ChatMode bool // -c: enter chat mode
CodeGen bool // -code: code execution tool
DigestPaths ParamArray // -d: paths to digest folders (RAG)
Embed bool // -e: write text embeddings to digest (RAG)
EmbModel string // embedding model name; parseFlags defaults it to "gemini-embedding-001"
FilePaths ParamArray // -f: GCS URI, file, directory or quoted pattern of files to attach
GenModel string // -m: model name; parseFlags defaults it to "gemini-2.5-flash"
GoogleSearch bool // -g: Google search tool
Help bool // -h: show available tools and help, then exit
ImgModality bool // -img: generate jpeg images
JSON bool // -json: structured output
K int // -k: maximum number of entries from digest to retrieve (default 3)
Lambda float64 // -l: balance accuracy and diversity querying digests [0.0,1.0] (default 0.5)
MCPServers ParamArray // -mcp: mcp stdio or streamable server commands
MCPSessions SessionArray // MCP client sessions; non-nil entries are closed by cleanup
OnlyKvs bool // -i: only store metadata with embeddings and ignore the content (RAG)
Interactive bool // terminal session? set to !isRedirected(os.Stdin) by parseFlags
Segment bool // -seg: segment image on VertexAI; SegmentForeground by default
SegmentBackground bool // -b: background segmentation mode
SegModel string // segmentation model name; parseFlags defaults it to "image-segmentation-001"
SystemInstruction bool // -s: treat argument as system prompt
TokenCount bool // -t: output total number of tokens
Temp float64 // -temp: sampling during response generation [0.0,2.0] (default 1.0)
ThinkingLevel genai.ThinkingLevel // -think: upper-cased flag value; default genai.ThinkingLevelUnspecified
Timeout time.Duration // -timeout: time limit for single turn content generation (default 300s)
Tool bool // -tool: invoke one of the tools
ToolRegistry ToolMap // initialised to an empty ToolMap by parseFlags
TopP float64 // -top_p: how the model selects tokens for generation [0.0,1.0] (default 0.95)
Unsafe bool // -unsafe: force generation when gen aborts with FinishReasonSafety
Verbose bool // -V: output model details, system instructions, chat history and thoughts
Version bool // -v: show version and exit
Walk bool // -r: process directory declared with -f (FilePaths) recursively
}
// main runs gen under a context that is cancelled on SIGINT or SIGTERM and
// maps the result of run to the process exit status (0 on success, 1 on error).
func main() {
	// create context that listens for OS interrupt signals
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	err := run(ctx)
	// Release the signal registration explicitly instead of deferring it:
	// os.Exit below terminates the process without running deferred calls.
	stop()
	if err == nil {
		return
	}
	// An error with an empty message means the problem has already been
	// reported to the user, so only the exit status is left to set.
	if err.Error() != "" {
		fmt.Fprintf(os.Stderr, "%v\n", err)
	}
	os.Exit(1)
}
func run(ctx context.Context) error {
params := &Parameters{}
keyVals := ParamMap{}
if err := parseFlags(params, &keyVals); err != nil {
return err
}
if params.Version {
printVersion()
return nil
}
defer cleanup(params)
// store keyVals and params in context
ctx = context.WithValue(ctx, "keyVals", keyVals)
ctx = context.WithValue(ctx, "params", params)
// stash MCP client sessions in params.MCPSessionsa
if params.Help || params.Tool {
if err := initMCPSessions(ctx, params); err != nil {
return fmt.Errorf("MCP error: %w", err)
}
}
// handle help and version flags before any further processing
// context includes params with list to known tools
if params.Help {
emitUsage(ctx, os.Stdout, true)
return nil
}
// argument validation
if err := isArgsInvalid(params, keyVals); err != nil {
fmt.Fprintf(os.Stderr, "%v\n\n", err)
emitUsage(ctx, os.Stdout, false)
return fmt.Errorf("")
}
if err := validateEnv(); err != nil {
return fmt.Errorf("Environment error: %w", err)
}
if err := genContent(ctx, os.Stdin, os.Stdout); err != nil {
return fmt.Errorf("Generation error: %w", err)
}
return nil
}
// parseFlags handles flag definitions and parameter map for variable substitutions in prompts.
func parseFlags(params *Parameters, keyVals *ParamMap) error {
// default parameter values
params.K = 3
params.Lambda = 0.5
params.Temp = 1.0
params.TopP = 0.95
params.ThinkingLevel = genai.ThinkingLevelUnspecified
params.Timeout = 300 * time.Second
params.EmbModel = "gemini-embedding-001"
params.GenModel = "gemini-2.5-flash"
params.SegModel = "image-segmentation-001"
if err := loadPrefs(params); err != nil {
return fmt.Errorf("Error loading preferences from %s: %v\n", DotGenRc, err)
}
flag.BoolVar(¶ms.Verbose, "V", false, "output model details, system instructions, chat history and thoughts")
flag.BoolVar(¶ms.SegmentBackground, "b", false, "background segmentation mode (default: foreground)")
flag.BoolVar(¶ms.ChatMode, "c", false, "enter chat mode (incompatible with -json, -img, -code or -g)")
flag.BoolVar(¶ms.CodeGen, "code", false, "code execution tool (incompatible with -g, -json, -img or -tool)")
flag.Var(¶ms.DigestPaths, "d", "path to a digest folder")
flag.BoolVar(¶ms.Embed, "e", false, fmt.Sprintf("write text embeddings to digest (default model \"%s\")", params.EmbModel))
flag.Var(¶ms.FilePaths, "f", "GCS URI, file, directory or quoted pattern of files to attach")
flag.BoolVar(¶ms.GoogleSearch, "g", false, "Google search tool (incompatible with -code, -json, -img and -tool)")
flag.BoolVar(¶ms.Help, "h", false, "show available tools, this help message and exit")
flag.BoolVar(¶ms.OnlyKvs, "i", false, "only store metadata with embeddings and ignore the content")
flag.BoolVar(¶ms.ImgModality, "img", false, "generate jpeg images (use -m to set a supported model)")
flag.BoolVar(¶ms.JSON, "json", false, "structured output (incompatible with -g, -code, -img and -tool)")
flag.IntVar(¶ms.K, "k", params.K, "maximum number of entries from digest to retrieve")
flag.Float64Var(¶ms.Lambda, "l", params.Lambda, "balance accuracy and diversity querying digests [0.0,1.0]")
flag.Func("think", fmt.Sprintf("%s, %s, %s or %s (default: %s)",
genai.ThinkingLevelMinimal,
genai.ThinkingLevelLow,
genai.ThinkingLevelMedium,
genai.ThinkingLevelHigh,
params.ThinkingLevel), func(val string) error {
params.ThinkingLevel = genai.ThinkingLevel(strings.ToUpper(val))
return nil
})
flag.StringVar(¶ms.GenModel, "m", params.GenModel, "model name")
flag.Var(¶ms.MCPServers, "mcp", "mcp stdio or streamable server command")
flag.Var(keyVals, "p", "prompt parameter value in format key=val")
flag.BoolVar(¶ms.Walk, "r", false, "process directory declared with -f recursively")
flag.BoolVar(¶ms.SystemInstruction, "s", false, "treat argument as system prompt")
flag.BoolVar(¶ms.Segment, "seg", false, fmt.Sprintf("segment image on VertexAI (default model \"%s\")", params.SegModel))
flag.BoolVar(¶ms.TokenCount, "t", false, "output total number of tokens")
flag.Float64Var(¶ms.Temp, "temp", params.Temp, "sampling during response generation [0.0,2.0]")
flag.DurationVar(¶ms.Timeout, "timeout", params.Timeout, "time limit for single turn content generation")
flag.BoolVar(¶ms.Tool, "tool", false, "invoke one of the tools (incompatible with -s, -g, -json, -img or -code)")
flag.Float64Var(¶ms.TopP, "top_p", params.TopP, "how the model selects tokens for generation [0.0,1.0]")
flag.BoolVar(¶ms.Unsafe, "unsafe", false, "force generation when gen aborts with FinishReasonSafety")
flag.BoolVar(¶ms.Version, "v", false, "show version and exit")
flag.Parse()
params.Args = flag.Args()
params.Interactive = !isRedirected(os.Stdin)
params.ToolRegistry = ToolMap{}
return nil
}
// printVersion writes a single line to stdout with gen's Version and Githash
// followed by the versions of the genai SDK and the MCP Go SDK as recorded in
// the binary's build information. A version is printed as an empty string
// when build information is unavailable or the module is not among the
// dependencies.
func printVersion() {
	const (
		genaiPath = "google.golang.org/genai"
		mcpPath   = "github.com/modelcontextprotocol/go-sdk"
	)
	// module path -> version, for the two modules reported
	found := make(map[string]string, 2)
	if info, ok := debug.ReadBuildInfo(); ok {
		for _, mod := range info.Deps {
			if mod.Path == genaiPath || mod.Path == mcpPath {
				found[mod.Path] = mod.Version
			}
		}
	}
	fmt.Printf("gen %s (%s sdk %s mcp %s)\n", Version, Githash, found[genaiPath], found[mcpPath])
}
// emitUsage overrides PrintDefaults to provide custom usage information.
//
// Everything is written to out, including the option list produced by the
// flag package. When emitTools is true the output additionally starts with a
// description of the program and contains a "Tools:" section listing the
// result of knownTools(ctx); otherwise only the usage line and the options
// are printed.
func emitUsage(ctx context.Context, out io.Writer, emitTools bool) {
	var tools string
	if emitTools {
		// Best effort: a failure to list tools leaves the section empty
		// rather than preventing the help text from being shown.
		tools, _ = knownTools(ctx)
		fmt.Fprintln(out, "Command-line interface to Google Gemini large language models")
		fmt.Fprintln(out, " Requires a valid GOOGLE_API_KEY environment variable set.")
		fmt.Fprintln(out, " VertexAI backend with valid GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION.")
		fmt.Fprintln(out, " Content is generated by a prompt and optional system instructions.")
		fmt.Fprintln(out, " Use - to assign stdin as prompt or as attached file.")
		fmt.Fprintf(out, "\n")
	}
	fmt.Fprintln(out, "Usage: gen [options] <prompt>")
	fmt.Fprintf(out, "\n")
	if emitTools {
		fmt.Fprintf(out, "Tools:\n%s\n", tools)
		fmt.Fprintf(out, "\n")
	}
	fmt.Fprintln(out, "Options:")
	fmt.Fprintf(out, "\n")

	// flag.PrintDefaults writes to the flag set's own output (stderr unless
	// configured), not to out. Redirect it for the duration of the call so the
	// option list lands next to the rest of the usage text, then restore it.
	prev := flag.CommandLine.Output()
	flag.CommandLine.SetOutput(out)
	defer flag.CommandLine.SetOutput(prev)
	flag.PrintDefaults()
}
// validateEnv checks for required Google Cloud/AI Studio credentials.
//
// The rules, in the order they are applied:
//   - at least one of GOOGLE_CLOUD_PROJECT or GOOGLE_API_KEY must be set;
//   - with GOOGLE_CLOUD_PROJECT set, GOOGLE_CLOUD_LOCATION must be set too;
//   - with both GOOGLE_CLOUD_PROJECT and GOOGLE_API_KEY set,
//     GOOGLE_GENAI_USE_VERTEXAI must be non-empty.
//
// It returns nil when the environment passes, otherwise an error describing
// the first rule that failed.
func validateEnv() error {
	project := os.Getenv("GOOGLE_CLOUD_PROJECT")
	apiKey := os.Getenv("GOOGLE_API_KEY")

	// without a project, an API key alone is sufficient
	if project == "" {
		if apiKey == "" {
			return fmt.Errorf("neither GOOGLE_CLOUD_PROJECT nor GOOGLE_API_KEY is set")
		}
		return nil
	}

	// a project is set: the remaining rules apply
	if os.Getenv("GOOGLE_CLOUD_LOCATION") == "" {
		return fmt.Errorf("GOOGLE_CLOUD_LOCATION must be set when using GOOGLE_CLOUD_PROJECT")
	}
	if apiKey != "" && os.Getenv("GOOGLE_GENAI_USE_VERTEXAI") == "" {
		return fmt.Errorf("set GOOGLE_GENAI_USE_VERTEXAI to 'true' or 'false' when both API Key and Project ID are present")
	}
	return nil
}
// cleanup releases what run acquired: it closes every non-nil MCP session in
// params.MCPSessions and, when params.TokenCount is set, prints the value of
// the package-level TokenCount counter to stdout.
func cleanup(params *Parameters) {
	for _, s := range params.MCPSessions {
		if s == nil {
			continue
		}
		s.Close()
	}

	if !params.TokenCount {
		return
	}
	// final token count report, surrounded by blank lines
	report := "\n" + tokens("%d tokens") + "\n"
	fmt.Printf(report, TokenCount.Load())
}