From a4dfb4da3e89b44f313b6969fa519d9815b8c2ba Mon Sep 17 00:00:00 2001 From: Pablo Ridolfi Date: Mon, 16 Mar 2026 18:23:21 +0100 Subject: [PATCH] Add Groq as optional AI provider (try GROQ_API_KEY first, else Gemini) - Add backends module with Gemini and Groq implementations; server uses a single generate(prompt, system_instruction, temperature) API - Prefer Groq when GROQ_API_KEY is set, otherwise use Gemini (GEMINI_API_KEY). No SSHQ_PROVIDER; selection is by which key is set - Require openai dependency for Groq (OpenAI-compatible API) - CLI: accept either GROQ_API_KEY or GEMINI_API_KEY; error if neither set - Document both providers and env vars (SSHQ_GEMINI_MODEL, SSHQ_GROQ_MODEL) - Tests: mock backend in server tests; single CLI test for missing both keys Signed-off-by: Pablo Ridolfi --- README.md | 28 +++++++++++++++++------- pyproject.toml | 3 ++- src/sshq/backends.py | 51 ++++++++++++++++++++++++++++++++++++++++++++ src/sshq/cli.py | 4 ++-- src/sshq/server.py | 39 +++++++++------------------------ tests/test_cli.py | 7 +++--- tests/test_server.py | 30 +++++++++++--------------- 7 files changed, 101 insertions(+), 61 deletions(-) create mode 100644 src/sshq/backends.py diff --git a/README.md b/README.md index 36312ae..ac8d7e6 100644 --- a/README.md +++ b/README.md @@ -16,14 +16,16 @@ When working on embedded boards (like Yocto or Buildroot builds), you often face ## How it Works -1. **The Host Server:** When you run `sshq`, it spins up a lightweight local web server in the background on your laptop. This server securely holds your `GEMINI_API_KEY` and talks to the Gemini API. +1. **The Host Server:** When you run `sshq`, it spins up a lightweight local web server in the background on your laptop. This server holds your API key and talks to Groq if `GROQ_API_KEY` is set, otherwise to Gemini. 2. **The Reverse Tunnel:** `sshq` wraps your standard `ssh` command and adds a reverse port forward to a random local port, creating a secure tunnel from the board back to your laptop. 3. **Transparent Injection:** During login, `sshq` passes a Python one-liner to the board (the `q` client script) and drops it into `~/.local/bin/q`, and immediately hands you an interactive shell. ## Prerequisites * Python 3.9 or higher (on your host machine). -* A Gemini API key (get one from Google AI Studio or Google Cloud console). +* An API key for at least one supported AI provider: + * **Groq** (free tier): get a key from [Groq Console](https://console.groq.com/). If set, `GROQ_API_KEY` is used first. + * **Gemini** (default otherwise): get a key from [Google AI Studio](https://aistudio.google.com/) or Google Cloud console. * Python 3 installed on the target embedded board (standard library only; no external packages required). ## Installation @@ -37,11 +39,19 @@ pip install git+https://github.com/pridolfi/sshq.git (Note: You can also clone the repo and use `pip install -e .` if you plan to modify the code). ## Usage -1. Export your API key in your terminal (or add it to your `~/.bashrc` / `~/.zshrc`): +1. Export your API key in your terminal (or add it to your `~/.bashrc` / `~/.zshrc`). If `GROQ_API_KEY` is set it is used; otherwise `GEMINI_API_KEY` is required. -```bash -export GEMINI_API_KEY="your_api_key_here" -``` + **Groq** (free tier): + + ```bash + export GROQ_API_KEY="your_groq_api_key_here" + ``` + + **Gemini** (used when `GROQ_API_KEY` is not set): + + ```bash + export GEMINI_API_KEY="your_gemini_api_key_here" + ``` 2. Connect to your board exactly as you normally would, just replace ssh with sshq: @@ -117,5 +127,7 @@ CPU Features: | Variable | Required | Default | Description | |----------|----------|---------|-------------| -| `GEMINI_API_KEY` | Yes | — | Your Gemini API key. Used by the local server to call the Gemini API. | -| `SSHQ_GEMINI_MODEL` | No | `gemini-2.5-flash` | Gemini model used for command suggestions. You can also use `gemini-2.5-flash-lite`, which typically offers a higher quota. | +| `GROQ_API_KEY` | No (tried first) | — | Your Groq API key (free at [console.groq.com](https://console.groq.com)). If set, Groq is used. | +| `GEMINI_API_KEY` | Yes (if Groq not set) | — | Your Gemini API key. | +| `SSHQ_GEMINI_MODEL` | No | `gemini-2.5-flash` | Gemini model (e.g. `gemini-2.5-flash-lite` for higher quota). | +| `SSHQ_GROQ_MODEL` | No | `llama-3.3-70b-versatile` | Groq model (e.g. `llama-3.1-8b-instant` for faster replies). | diff --git a/pyproject.toml b/pyproject.toml index a2193e8..6967a8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ license = "MIT" requires-python = ">=3.9" dependencies = [ "flask", - "google-genai" + "google-genai", + "openai", ] authors = [ { name = "Pablo Ridolfi", email = "pabloridolfi@gmail.com" } diff --git a/src/sshq/backends.py b/src/sshq/backends.py new file mode 100644 index 0000000..0f41317 --- /dev/null +++ b/src/sshq/backends.py @@ -0,0 +1,51 @@ +"""AI provider backends for sshq. Each backend implements generate(prompt, system_instruction, temperature).""" +import os + + +def _gemini_generate(prompt, system_instruction, temperature=0.0): + from google import genai + from google.genai import types + + client = genai.Client() + model = os.environ.get("SSHQ_GEMINI_MODEL", "gemini-2.5-flash") + response = client.models.generate_content( + model=model, + contents=prompt, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=temperature, + ), + ) + return response.text.strip() + + +def _groq_generate(prompt, system_instruction, temperature=0.0): + from openai import OpenAI + + client = OpenAI( + base_url="https://api.groq.com/openai/v1", + api_key=os.environ.get("GROQ_API_KEY"), + ) + model = os.environ.get("SSHQ_GROQ_MODEL", "llama-3.3-70b-versatile") + # Groq converts temperature=0 to 1e-8; use a tiny value for deterministic output + t = max(1e-8, temperature) + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_instruction}, + {"role": "user", "content": prompt}, + ], + temperature=t, + ) + return (response.choices[0].message.content or "").strip() + + +def get_backend(): + """Return the active backend function (prompt, system_instruction, temperature=0.0) -> str. + Uses Groq if GROQ_API_KEY is set, otherwise Gemini (requires GEMINI_API_KEY). + """ + if os.environ.get("GROQ_API_KEY"): + return _groq_generate + if os.environ.get("GEMINI_API_KEY"): + return _gemini_generate + raise ValueError("Set GROQ_API_KEY or GEMINI_API_KEY.") diff --git a/src/sshq/cli.py b/src/sshq/cli.py index 6ef90a6..2be14e1 100644 --- a/src/sshq/cli.py +++ b/src/sshq/cli.py @@ -113,8 +113,8 @@ def main(): """ def main(): - if not os.environ.get("GEMINI_API_KEY"): - print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr) + if not os.environ.get("GROQ_API_KEY") and not os.environ.get("GEMINI_API_KEY"): + print("Error: Set GROQ_API_KEY or GEMINI_API_KEY.", file=sys.stderr) sys.exit(1) prog = os.path.basename(sys.argv[0]) diff --git a/src/sshq/server.py b/src/sshq/server.py index 23ba75a..b6cf1a2 100644 --- a/src/sshq/server.py +++ b/src/sshq/server.py @@ -1,9 +1,8 @@ import logging -import os import flask.cli from flask import Flask, request, jsonify -from google import genai -from google.genai import types + +from .backends import get_backend # Suppress standard Werkzeug request logging log = logging.getLogger('werkzeug') @@ -13,7 +12,8 @@ flask.cli.show_server_banner = lambda *args: None app = Flask(__name__) -client = None # Initialized when the server starts +backend = None # Set to generate(prompt, system_instruction, temperature) in start_server + @app.route('/ask', methods=['POST']) def ask(): @@ -27,18 +27,9 @@ def ask(): "Do NOT use markdown formatting (like ```bash). Do NOT provide explanations." ) - model = os.environ.get("SSHQ_GEMINI_MODEL", "gemini-2.5-flash") - try: - response = client.models.generate_content( - model=model, - contents=data['prompt'], - config=types.GenerateContentConfig( - system_instruction=system_instruction, - temperature=0.0 - ) - ) - return jsonify({"command": response.text.strip()}) + text = backend(data['prompt'], system_instruction, temperature=0.0) + return jsonify({"command": text}) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -56,27 +47,17 @@ def analyze(): "You can use bullets and numbered lists to format the answer, in plain ASCII." ) - model = os.environ.get("SSHQ_GEMINI_MODEL", "gemini-2.5-flash") - # Combine content and user question so the model has full context contents = f"Content to analyze:\n\n{data['content']}\n\nUser question: {data['prompt']}" try: - response = client.models.generate_content( - model=model, - contents=contents, - config=types.GenerateContentConfig( - system_instruction=system_instruction, - temperature=0.0, - ), - ) - return jsonify({"analysis": response.text.strip()}) + text = backend(contents, system_instruction, temperature=0.0) + return jsonify({"analysis": text}) except Exception as e: return jsonify({"error": str(e)}), 500 def start_server(port): - global client - # Client automatically picks up the GEMINI_API_KEY environment variable - client = genai.Client() + global backend + backend = get_backend() app.run(port=port, host='127.0.0.1', debug=False) diff --git a/tests/test_cli.py b/tests/test_cli.py index 53f14f4..693acb3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -51,9 +51,8 @@ def test_version_exits_zero_and_prints_version(argv): assert err == "" -def test_missing_gemini_api_key_exits_nonzero_and_prints_to_stderr(): - env = {k: v for k, v in os.environ.items() if k != "GEMINI_API_KEY"} +def test_missing_both_api_keys_exits_nonzero_and_prints_to_stderr(): + env = {k: v for k, v in os.environ.items() if k not in ("GEMINI_API_KEY", "GROQ_API_KEY")} code, out, err = run_main(["user@host"], env=env, clear_env=True) assert code != 0 - assert "GEMINI_API_KEY" in err - assert "not set" in err + assert "GROQ_API_KEY" in err and "GEMINI_API_KEY" in err diff --git a/tests/test_server.py b/tests/test_server.py index 0f2e41b..224af55 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,10 +16,10 @@ def client(app): @pytest.fixture(autouse=True) -def mock_genai_client(): - """Mock the genai client so we never call the real API.""" +def mock_backend(): + """Mock the AI backend so we never call real APIs.""" mock = MagicMock() - with patch("sshq.server.client", mock): + with patch("sshq.server.backend", mock): yield mock @@ -35,19 +35,17 @@ def test_ask_without_prompt_returns_400(client): assert r.status_code == 400 -def test_ask_with_prompt_returns_command(client, mock_genai_client): - mock_response = MagicMock() - mock_response.text = " ls -la\n" - mock_genai_client.models.generate_content.return_value = mock_response +def test_ask_with_prompt_returns_command(client, mock_backend): + mock_backend.return_value = "ls -la" r = client.post("/ask", json={"prompt": "list files"}) assert r.status_code == 200 assert r.json == {"command": "ls -la"} - mock_genai_client.models.generate_content.assert_called_once() + mock_backend.assert_called_once() -def test_ask_on_api_error_returns_500(client, mock_genai_client): - mock_genai_client.models.generate_content.side_effect = RuntimeError("API error") +def test_ask_on_api_error_returns_500(client, mock_backend): + mock_backend.side_effect = RuntimeError("API error") r = client.post("/ask", json={"prompt": "do something"}) assert r.status_code == 500 @@ -70,10 +68,8 @@ def test_analyze_without_prompt_or_content_returns_400(client): assert r.status_code == 400 -def test_analyze_with_prompt_and_content_returns_analysis(client, mock_genai_client): - mock_response = MagicMock() - mock_response.text = "I see 2 failures in the log." - mock_genai_client.models.generate_content.return_value = mock_response +def test_analyze_with_prompt_and_content_returns_analysis(client, mock_backend): + mock_backend.return_value = "I see 2 failures in the log." r = client.post( "/analyze", @@ -81,11 +77,11 @@ def test_analyze_with_prompt_and_content_returns_analysis(client, mock_genai_cli ) assert r.status_code == 200 assert r.json == {"analysis": "I see 2 failures in the log."} - mock_genai_client.models.generate_content.assert_called_once() + mock_backend.assert_called_once() -def test_analyze_on_api_error_returns_500(client, mock_genai_client): - mock_genai_client.models.generate_content.side_effect = RuntimeError("API error") +def test_analyze_on_api_error_returns_500(client, mock_backend): + mock_backend.side_effect = RuntimeError("API error") r = client.post( "/analyze",