diff --git a/adl-search-service/app/main.py b/adl-search-service/app/main.py index 10a064b2..80645c9e 100644 --- a/adl-search-service/app/main.py +++ b/adl-search-service/app/main.py @@ -1,139 +1,102 @@ -# SPDX-FileCopyrightText: 2025 Deutsche Telekom AG and others -# -# SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Any, Dict -import re +from typing import List, Optional +import os import json -import yaml -from fastapi import FastAPI, UploadFile, File, Form, HTTPException + +from fastapi import FastAPI, HTTPException from pydantic import BaseModel +from sentence_transformers import SentenceTransformer +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams -app = FastAPI(title="ADL Validator API") - - -class SyntaxErrorDetail(BaseModel): - line: Optional[int] - message: str - - -class ValidationResult(BaseModel): - syntax_errors: List[SyntaxErrorDetail] - used_tools: List[str] - references: List[str] - language: Optional[str] - - -def try_parse_json(text: str) -> Any: - return json.loads(text) - - -def try_parse_yaml(text: str) -> Any: - return yaml.safe_load(text) - - -def find_used_tools(text: str) -> List[str]: - tools = set() - # common patterns: tools:, uses:, run_tool(...), call_tool(...), @toolName - for m in re.finditer(r"\b(?:tools|uses|use|tool)\b[:=]?\s*([A-Za-z0-9_.-]+)", text, re.IGNORECASE): - tools.add(m.group(1)) - for m in re.finditer(r"([A-Za-z_][A-Za-z0-9_]*)\s*\(", text): - name = m.group(1) - if name.lower().startswith(("run_","call_","invoke_","tool_")): - tools.add(name) - for m in re.finditer(r"@([A-Za-z_][A-Za-z0-9_-]*)", text): - tools.add(m.group(1)) - return sorted(tools) - - -def find_references(text: str) -> List[str]: - refs = set() - # URLs - for m in re.finditer(r"https?://[^\s'\"<>]+", text): - refs.add(m.group(0)) - # file paths (simple heuristic) - for m in re.finditer(r"(?:[A-Za-z]:)?[\\/][\w\-./\\]+\.[a-zA-Z0-9]+", text): - refs.add(m.group(0)) - # explicit ref: tokens - for m in re.finditer(r"\bref(?:erence)?s?\b[:=]?\s*([A-Za-z0-9_./:-]+)", text, re.IGNORECASE): - refs.add(m.group(1)) - return sorted(refs) - - -def find_syntax_issues_text(text: str) -> List[SyntaxErrorDetail]: - errors: List[SyntaxErrorDetail] = [] - # unbalanced braces/brackets/parens - pairs = {'(': ')', '{': '}', '[': ']'} - stack: List[Dict[str, Any]] = [] - for i, ch in enumerate(text): - if ch in pairs: - stack.append({'char': ch, 'pos': i}) - elif ch in pairs.values(): - if not stack: - errors.append(SyntaxErrorDetail(line=None, message=f"Unmatched closing '{ch}' at pos {i}")) - else: - last = stack.pop() - if pairs[last['char']] != ch: - errors.append(SyntaxErrorDetail(line=None, message=f"Mismatched '{last['char']}' with '{ch}' at pos {i}")) - for leftover in stack: - errors.append(SyntaxErrorDetail(line=None, message=f"Unclosed '{leftover['char']}' starting at pos {leftover['pos']}")) - - # unclosed quotes - for quote in ['"', "'"]: - if text.count(quote) % 2 != 0: - errors.append(SyntaxErrorDetail(line=None, message=f"Unclosed quote {quote}")) - - # indentation mix (tabs vs spaces) - has_tabs = any(line.startswith('\t') for line in text.splitlines()) - has_spaces_indented = any(re.match(r" {2,}", line) for line in text.splitlines()) - if has_tabs and has_spaces_indented: - errors.append(SyntaxErrorDetail(line=None, message="Mixed tabs and spaces for indentation")) - - return errors - - -@app.post('/validate', response_model=ValidationResult) -async def validate_adl(file: UploadFile = File(None), text: str = Form(None)) -> ValidationResult: - if file is None and (text is None or text.strip() == ''): - raise HTTPException(status_code=400, detail="Provide an ADL file upload or `text` form field") - - if file is not None: - content_bytes = await file.read() - try: - content = content_bytes.decode('utf-8') - except Exception: - content = content_bytes.decode('latin-1') - else: - content = text +QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost") +QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) +COLLECTION_NAME = os.getenv("QDRANT_COLLECTION", "adl_collection") - syntax_errors: List[SyntaxErrorDetail] = [] - language: Optional[str] = None +app = FastAPI(title="ADL Search Service") + + +class ADL(BaseModel): + id: str + title: Optional[str] + content: str + metadata: Optional[dict] = None - # try JSON - try: - _ = try_parse_json(content) - language = 'json' - except Exception as e_json: - # try YAML - try: - _ = try_parse_yaml(content) - language = 'yaml' - except Exception as e_yaml: - language = 'adl-text' - # best-effort syntax heuristics - syntax_errors = find_syntax_issues_text(content) - used_tools = find_used_tools(content) - references = find_references(content) +class ConversationRequest(BaseModel): + conversation: str + top_k: Optional[int] = 5 - return ValidationResult( - syntax_errors=syntax_errors, - used_tools=used_tools, - references=references, - language=language, - ) +def get_model(): + return SentenceTransformer("all-MiniLM-L6-v2") -if __name__ == '__main__': + +def get_qdrant_client(): + return QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) + + +@app.on_event("startup") +def startup_event(): + global model, qdrant + model = get_model() + qdrant = get_qdrant_client() + # Ensure collection exists + try: + if COLLECTION_NAME not in [c.name for c in qdrant.get_collections().result]: + vector_size = model.get_sentence_embedding_dimension() + qdrant.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + except Exception: + # older qdrant-client may return different types, fallback to simple create + vector_size = model.get_sentence_embedding_dimension() + try: + qdrant.create_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + except Exception: + pass + + +@app.post("/index") +def index_adls(adls: List[ADL]): + if not adls: + raise HTTPException(status_code=400, detail="No ADLs provided") + texts = [a.content for a in adls] + ids = [a.id for a in adls] + embeddings = model.encode(texts, show_progress_bar=False).tolist() + points = [ + {"id": ids[i], "vector": embeddings[i], "payload": {"title": adls[i].title, "metadata": adls[i].metadata}} + for i in range(len(adls)) + ] + qdrant.upsert(collection_name=COLLECTION_NAME, points=points) + return {"indexed": len(points)} + + +@app.post("/query") +def query_adls(req: ConversationRequest): + if not req.conversation: + raise HTTPException(status_code=400, detail="Conversation text required") + q_emb = model.encode([req.conversation])[0].tolist() + search_result = qdrant.search(collection_name=COLLECTION_NAME, query_vector=q_emb, limit=req.top_k) + results = [] + for r in search_result: + results.append({ + "id": r.id, + "score": r.score, + "payload": r.payload, + }) + return {"results": results} + + +@app.get("/health") +def health(): + return {"status": "ok"} + + +if __name__ == "__main__": import uvicorn - uvicorn.run('main:app', host='127.0.0.1', port=8000, reload=True) + uvicorn.run(app, host="0.0.0.0", port=8000)