forked from jdevoo/gen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcmd.go
More file actions
312 lines (285 loc) · 8.92 KB
/
cmd.go
File metadata and controls
312 lines (285 loc) · 8.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"strings"
"syscall"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
// Version information, populated by make.
// version, golang and githash are printed by the -v flag as
// "gen version <version> (<golang> <githash>)". Presumably they are injected
// at link time (e.g. -ldflags "-X main.version=...") — confirm against the
// Makefile. They are empty strings when not set.
//
// Token count accumulator in case of CTRL-C.
// tokenCount sums the TotalTokens reported by CountTokens when -t is given.
// It lives at package level so the SIGINT/SIGTERM handler in main can still
// print the running total when the user aborts generation.
// NOTE(review): it is written by emitGen and read by the signal goroutine in
// main without synchronization, which is a data race under -race.
var (
version string
golang string
githash string
tokenCount int32
)
// emitUsage writes the command's usage text to out: a synopsis, a short
// description of what gen does, and the defaults of every registered flag.
//
// The flag defaults are also written to out. flag.PrintDefaults would
// otherwise write them to the flag set's own output, which is stderr by
// default. The flag set's previous output destination is restored afterwards.
func emitUsage(out io.Writer) {
	// The program name is passed as an argument rather than spliced into the
	// format string, so a '%' in the binary's path cannot corrupt the output.
	fmt.Fprintf(out, "Usage: %s [options] <prompt>\n", filepath.Base(os.Args[0]))
	fmt.Fprintf(out, "\n")
	fmt.Fprintf(out, "Command-line interface to Google Gemini large language models\n")
	fmt.Fprintf(out, " Requires a valid GEMINI_API_KEY environment variable set.\n")
	fmt.Fprintf(out, " Content is generated by a prompt and an optional system instruction.\n")
	fmt.Fprintf(out, " Additionally, supports stdin and .prompt files as valid prompt parts.\n")
	fmt.Fprintf(out, "\n")
	fmt.Fprintf(out, "Options:\n")

	// Keep the whole usage message on a single destination.
	prev := flag.CommandLine.Output()
	flag.CommandLine.SetOutput(out)
	defer flag.CommandLine.SetOutput(prev)
	flag.PrintDefaults()
}
// emitGen implements the gen command. It parses the command-line flags,
// assembles the prompt from stdin, the -f file and the positional arguments,
// configures the Gemini model, and streams the generated content to out.
// With -c it then continues as an interactive chat. Turns are read from in,
// or from /dev/tty when stdin supplied the initial prompt. The chat ends on
// two consecutive blank lines or at end of input.
//
// It returns the process exit code:
//   - 1 when GEMINI_API_KEY is unset or empty;
//   - 1 when the flag combination is invalid (usage is printed to out);
//   - 0 otherwise.
//
// Other failures terminate the process through log.Fatal or genLogFatal.
func emitGen(in io.Reader, out io.Writer) int {
	// Check for API key
	if val, ok := os.LookupEnv("GEMINI_API_KEY"); !ok || len(val) == 0 {
		fmt.Fprintf(out, "Environment variable GEMINI_API_KEY not set!\n")
		return 1
	}

	// Flag handling
	verboseFlag := flag.Bool("V", false, "output model details and chat history\ndetails include model name | maxInputTokens | maxOutputTokens | temp | top_p | top_k")
	chatModeFlag := flag.Bool("c", false, "enter chat mode after content generation\ntype two consecutive blank lines to exit\nnot supported on windows when stdin used")
	filePathVal := flag.String("f", "", "attach file to prompt where string is the path to the file\nfile with extension .prompt treated as user prompt or system instruction")
	helpFlag := flag.Bool("h", false, "show this help message and exit")
	jsonFlag := flag.Bool("json", false, "response in JavaScript Object Notation")
	modelName := flag.String("m", "gemini-1.5-flash", "generative model name")
	keyVals := ParamMap{}
	flag.Var(&keyVals, "p", "zero or more prompt parameter values in format key=val\nreplaces all occurrences of {key} in prompt with val")
	systemInstructionFlag := flag.Bool("s", false, "treat first of stdin, file option or argument as system instruction")
	tokenCountFlag := flag.Bool("t", false, "output total number of tokens")
	tempVal := flag.Float64("temp", 1.0, "changes sampling during response generation [0.0,2.0]")
	toolFlag := flag.Bool("tool", false, fmt.Sprintf("invoke one of the tools {%s}", knownTools()))
	topPVal := flag.Float64("top_p", 0.95, "changes how the model selects tokens for generation [0.0,1.0]")
	unsafeFlag := flag.Bool("unsafe", false, "force generation when gen aborts with FinishReasonSafety")
	versionFlag := flag.Bool("v", false, "show version and exit")
	flag.Parse()

	// Handle version flag
	if *versionFlag {
		fmt.Fprintf(out, "gen version %s (%s %s)\n", version, golang, githash)
		return 0
	}

	args := flag.Args()
	hasArgs := len(args) > 0
	// A -f file ending in .prompt is read as text instead of being uploaded.
	// The extension is taken from a slash-separated copy of the path so that
	// dots in Windows directory names are not mistaken for an extension.
	promptFile := path.Ext(filepath.ToSlash(*filePathVal)) == ".prompt"

	// Set stdin as prompt, if provided
	var prompt []genai.Part
	stdinFlag := hasInputFromStdin(in)
	if stdinFlag {
		data, err := io.ReadAll(in)
		if err != nil {
			log.Fatal(err)
		}
		prompt = append(prompt, genai.Text(searchReplace(string(data), keyVals)))
	}

	// Handle invalid argument and option combinations
	invalid := *helpFlag ||
		// temp out of range (negated form so that NaN is rejected as well)
		!(*tempVal >= 0 && *tempVal <= 2) ||
		// topP out of range (negated form so that NaN is rejected as well)
		!(*topPVal >= 0 && *topPVal <= 1) ||
		// no prompt as stdin, argument or file
		(!stdinFlag && !hasArgs && !promptFile) ||
		// chat after stdin input reopens /dev/tty, which windows lacks
		(runtime.GOOS == "windows" && stdinFlag && *chatModeFlag) ||
		(!*chatModeFlag && *systemInstructionFlag &&
			// no chat mode, stdin as system instruction, no prompt
			((stdinFlag && !hasArgs && !promptFile) ||
				// no chat mode, argument as system instruction, no prompt
				(!stdinFlag && hasArgs && !promptFile) ||
				// no chat mode, file as system instruction, no prompt
				(!stdinFlag && !hasArgs && promptFile))) ||
		// chat mode, system instruction, attached file is not a .prompt file.
		// This is rejected whichever of stdin or the argument supplies the
		// instruction. With only one of them there is no text prompt left. With
		// both, the argument would overwrite the stdin instruction.
		(*chatModeFlag && *systemInstructionFlag && *filePathVal != "" && !promptFile)
	if invalid {
		emitUsage(out)
		return 1
	}

	// Create a genai client
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
	if err != nil {
		genLogFatal(err)
	}
	defer client.Close()
	model := client.GenerativeModel(*modelName)

	// Set temperature and top_p from args or model defaults
	model.SetTemperature(float32(*tempVal))
	model.SetTopP(float32(*topPVal))

	// Handle json flag
	if *jsonFlag {
		model.ResponseMIMEType = "application/json"
	}

	// Handle unsafe flag
	if *unsafeFlag {
		model.SafetySettings = []*genai.SafetySetting{
			{
				Category:  genai.HarmCategoryDangerousContent,
				Threshold: genai.HarmBlockNone,
			},
			{
				Category:  genai.HarmCategoryHarassment,
				Threshold: genai.HarmBlockNone,
			},
			{
				Category:  genai.HarmCategoryHateSpeech,
				Threshold: genai.HarmBlockNone,
			},
			{
				Category:  genai.HarmCategorySexuallyExplicit,
				Threshold: genai.HarmBlockNone,
			},
		}
	}

	// Handle tool flag registering tools declared in the tools.go file
	if *toolFlag {
		registerTools(model, genai.FunctionCallingAny)
	} else {
		registerTools(model, genai.FunctionCallingNone)
	}

	// Promote stdin prompt as system instruction
	if stdinFlag && *systemInstructionFlag {
		model.SystemInstruction = &genai.Content{
			Parts: prompt,
		}
		prompt = nil
	}

	// Handle file option: a .prompt file becomes prompt text (or the system
	// instruction), any other file is uploaded and deleted again on return.
	var file *genai.File
	if *filePathVal != "" {
		f, err := os.Open(*filePathVal)
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()
		if promptFile {
			data, err := io.ReadAll(f)
			if err != nil {
				log.Fatal(err)
			}
			prompt = append(prompt, genai.Text(searchReplace(string(data), keyVals)))
			// The argument remains as the user prompt, so the file is the
			// instruction.
			if !stdinFlag && hasArgs && *systemInstructionFlag {
				model.SystemInstruction = &genai.Content{
					Parts: prompt,
				}
				prompt = nil
			}
		} else {
			file, err = uploadFile(ctx, client, *filePathVal)
			if err != nil {
				genLogFatal(err)
			}
			defer func() {
				if err := client.DeleteFile(ctx, file.Name); err != nil {
					genLogFatal(err)
				}
			}()
		}
	}

	// Handle argument as prompt
	if hasArgs {
		prompt = append(prompt, genai.Text(searchReplace(strings.Join(args, " "), keyVals)))
		if !promptFile && *chatModeFlag && *systemInstructionFlag { // argument as system instruction for chat
			model.SystemInstruction = &genai.Content{
				Parts: prompt[len(prompt)-1:],
			}
			prompt = prompt[:len(prompt)-1]
		}
	}

	// Send FileData to model if available
	if file != nil {
		prompt = append(prompt, genai.FileData{MIMEType: file.MIMEType, URI: file.URI})
	}

	// Handle verbose flag and output model information
	if *verboseFlag {
		info, err := model.Info(ctx)
		if err != nil {
			genLogFatal(err)
		}
		fmt.Fprintf(out, "\033[36m%s | %d | %d | %.2f | %.2f | %d\033[0m\n", info.Name, info.InputTokenLimit, info.OutputTokenLimit, *tempVal, *topPVal, info.TopK)
	}

	// Start chat session
	sess := model.StartChat()

	// Chat turns come from in, unless stdin was consumed as the initial
	// prompt, in which case the user's terminal is opened instead.
	tty := in
	if stdinFlag && *chatModeFlag {
		console, err := os.Open("/dev/tty")
		if err != nil {
			log.Fatal(err)
		}
		defer console.Close()
		tty = console
	}

	// Main chat loop
	for {
		if len(prompt) > 0 {
			iter := sess.SendMessageStream(ctx, prompt...)
			if *tokenCountFlag {
				res, err := model.CountTokens(ctx, prompt...)
				if err != nil {
					genLogFatal(err)
				}
				tokenCount += res.TotalTokens
			}
			for {
				resp, err := iter.Next()
				if err == iterator.Done {
					break
				}
				if err != nil {
					fmt.Fprintf(out, "\n")
					genLogFatal(err)
				}
				emitGeneratedResponse(resp, out)
			}
		}
		if !*chatModeFlag {
			break
		}
		if *verboseFlag {
			for i, c := range sess.History {
				fmt.Fprintf(out, "\033[36m%02d: %+v\033[0m\n", i, c)
			}
		}
		fmt.Fprintf(out, "\n")
		input, err := readLine(tty)
		// End of input (e.g. CTRL-D) ends the chat normally, so deferred
		// cleanup and the token report still run.
		// NOTE(review): assumes readLine returns io.EOF unwrapped — confirm.
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		// Check for double blank line exit condition
		if input == "" {
			input, err = readLine(tty)
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Fatal(err)
			}
			if input == "" {
				break // exit chat mode
			}
		}
		prompt = []genai.Part{genai.Text(input)}
	}

	if *tokenCountFlag {
		fmt.Fprintf(out, "\n\033[31m%d tokens\033[0m\n", tokenCount)
	}
	return 0
}
// main runs gen against the process's standard streams and exits with the
// status code it returns. If SIGINT or SIGTERM arrives first, the tokens
// counted so far (if any) are printed and the process exits with status 1.
func main() {
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt, syscall.SIGTERM)

	onInterrupt := func() {
		<-interrupts
		if tokenCount > 0 {
			fmt.Printf("\n\033[31m%d tokens\033[0m\n", tokenCount)
		}
		os.Exit(1)
	}
	go onInterrupt()

	status := emitGen(os.Stdin, os.Stdout)
	os.Exit(status)
}