diff --git a/pyproject.toml b/pyproject.toml index f20f71b..a2193e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ sshq = "sshq.cli:main" [tool.pytest.ini_options] testpaths = ["tests"] +filterwarnings = [ + "ignore::DeprecationWarning:google.genai.*", +] [tool.ruff] target-version = "py39" diff --git a/src/sshq/cli.py b/src/sshq/cli.py index f120a5c..6ef90a6 100644 --- a/src/sshq/cli.py +++ b/src/sshq/cli.py @@ -16,8 +16,53 @@ def main(): if len(sys.argv) < 2: print("Usage: q ") + print(" q --analyze ") sys.exit(1) + # q --analyze + if len(sys.argv) >= 3 and sys.argv[1] == "--analyze": + filepath = sys.argv[2] + prompt = " ".join(sys.argv[3:]).strip() + if not prompt: + print("Usage: q --analyze ") + sys.exit(1) + try: + with open(filepath, encoding="utf-8", errors="replace") as f: + content = f.read() + except OSError as e: + print(f"Error: cannot read {{filepath}}: {{e}}") + sys.exit(1) + + data = json.dumps({{"prompt": prompt, "content": content}}).encode("utf-8") + req = urllib.request.Request( + "http://localhost:{port}/analyze", + data=data, + headers={{'Content-Type': 'application/json'}}, + ) + try: + with urllib.request.urlopen(req) as response: + result = json.loads(response.read().decode()) + if "error" in result: + print(f"Error: {{result['error']}}") + sys.exit(1) + print(result.get("analysis", "")) + except urllib.error.HTTPError as e: + try: + body = e.read().decode() + res = json.loads(body) + msg = res.get("error", body or e.reason) + except Exception: + msg = e.reason or str(e) + print(f"Error: {{msg}}") + sys.exit(1) + except urllib.error.URLError as e: + print("Error: Tunnel is down. Did you connect using sshq?") + if e.reason: + print(f" ({{e.reason}})") + sys.exit(1) + return + + # q -> suggest command prompt = " ".join(sys.argv[1:]) data = json.dumps({{"prompt": prompt}}).encode('utf-8') req = urllib.request.Request( diff --git a/src/sshq/server.py b/src/sshq/server.py index 58ff427..95f3cfb 100644 --- a/src/sshq/server.py +++ b/src/sshq/server.py @@ -42,6 +42,37 @@ def ask(): except Exception as e: return jsonify({"error": str(e)}), 500 + +@app.route('/analyze', methods=['POST']) +def analyze(): + data = request.json + if not data or 'prompt' not in data or 'content' not in data: + return jsonify({"error": "prompt and content are required"}), 400 + + system_instruction = ( + "You are an expert embedded Linux engineer analyzing text and log files. " + "Answer the user's question about the provided content clearly and concisely. " + "Do NOT use markdown formatting for the answer, and do NOT use markdown code fences for the content itself." + ) + + model = os.environ.get("SSHQ_GEMINI_MODEL", "gemini-2.5-flash") + # Combine content and user question so the model has full context + contents = f"Content to analyze:\n\n{data['content']}\n\nUser question: {data['prompt']}" + + try: + response = client.models.generate_content( + model=model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_instruction, + temperature=0.0, + ), + ) + return jsonify({"analysis": response.text.strip()}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + def start_server(port): global client # Client automatically picks up the GEMINI_API_KEY environment variable diff --git a/tests/test_server.py b/tests/test_server.py index 996d6e0..0f2e41b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,4 @@ -"""Tests for sshq server /ask endpoint.""" +"""Tests for sshq server /ask and /analyze endpoints.""" from unittest.mock import MagicMock, patch import pytest @@ -23,6 +23,9 @@ def mock_genai_client(): yield mock +# --- /ask --- + + def test_ask_without_prompt_returns_400(client): r = client.post("/ask", json={}) assert r.status_code == 400 @@ -50,3 +53,44 @@ def test_ask_on_api_error_returns_500(client, mock_genai_client): assert r.status_code == 500 assert "error" in r.json assert "API error" in r.json["error"] + + +# --- /analyze --- + + +def test_analyze_without_prompt_or_content_returns_400(client): + r = client.post("/analyze", json={}) + assert r.status_code == 400 + assert "prompt" in r.json["error"] and "content" in r.json["error"] + + r = client.post("/analyze", json={"prompt": "explain"}) + assert r.status_code == 400 + + r = client.post("/analyze", json={"content": "some log"}) + assert r.status_code == 400 + + +def test_analyze_with_prompt_and_content_returns_analysis(client, mock_genai_client): + mock_response = MagicMock() + mock_response.text = "I see 2 failures in the log." + mock_genai_client.models.generate_content.return_value = mock_response + + r = client.post( + "/analyze", + json={"prompt": "any failures?", "content": "ERROR: disk full\nERROR: timeout"}, + ) + assert r.status_code == 200 + assert r.json == {"analysis": "I see 2 failures in the log."} + mock_genai_client.models.generate_content.assert_called_once() + + +def test_analyze_on_api_error_returns_500(client, mock_genai_client): + mock_genai_client.models.generate_content.side_effect = RuntimeError("API error") + + r = client.post( + "/analyze", + json={"prompt": "explain", "content": "log line"}, + ) + assert r.status_code == 500 + assert "error" in r.json + assert "API error" in r.json["error"]