diff --git a/ChatQnA/chatqna.py b/ChatQnA/chatqna.py
index 95318e9613..b90807a869 100644
--- a/ChatQnA/chatqna.py
+++ b/ChatQnA/chatqna.py
@@ -48,6 +48,10 @@ def generate_rag_prompt(question, documents):
 LLM_SERVER_HOST_IP = os.getenv("LLM_SERVER_HOST_IP", "0.0.0.0")
 LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 80))
 LLM_MODEL = os.getenv("LLM_MODEL", "Intel/neural-chat-7b-v3-3")
+# Optional prompt override: when set (and RAG is not enabled) it replaces the user input entirely.
+LLM_PROMPT = os.getenv("LLM_PROMPT", None)
+# Set RAG="enabled" to strip retrieved context and forward only the bare question to the LLM.
+RAG = os.getenv("RAG", None)
@@ -63,7 +67,20 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **k
         # convert TGI/vLLM to unified OpenAI /v1/chat/completions format
         next_inputs = {}
         next_inputs["model"] = LLM_MODEL
-        next_inputs["messages"] = [{"role": "user", "content": inputs["inputs"]}]
+        if RAG == "enabled":
+            # Both LLM_PROMPT branches of the original change did the exact same
+            # thing when RAG was enabled, so the RAG check comes first.  Keep
+            # everything up to and including the "### Question:" marker; guard
+            # against the marker being absent (find() returning -1 would
+            # otherwise silently truncate the query to its first 12 characters).
+            marker = "### Question:"
+            idx = inputs["inputs"].find(marker)
+            content = inputs["inputs"][: idx + len(marker)] if idx != -1 else inputs["inputs"]
+        elif LLM_PROMPT is not None:
+            # NOTE(review): this discards the user's input in favor of the
+            # LLM_PROMPT override, matching the original change — confirm intended.
+            content = LLM_PROMPT
+        else:
+            content = inputs["inputs"]
+        next_inputs["messages"] = [{"role": "user", "content": content}]
         next_inputs["max_tokens"] = llm_parameters_dict["max_tokens"]
         next_inputs["top_p"] = llm_parameters_dict["top_p"]
         next_inputs["stream"] = inputs["streaming"]
def align_generator(self, gen, **kwargs):
    """Adapt a raw TGI/vLLM SSE byte stream into unified ``data: ...`` chunks.

    The original patch deleted this ``def`` line while introducing
    ``split_lines``, leaving ``gen`` unbound; the signature is restored here
    and ``split_lines`` is nested as a private helper.

    :param gen: iterator of raw ``bytes`` chunks from the LLM server; a single
        chunk may contain one or several coalesced ``data:{json}`` segments.
    :yields: ``data: <repr of utf-8 encoded delta content>`` strings for each
        content-bearing segment (or the raw payload as a fallback when a
        segment is not parseable JSON), terminated by ``data: [DONE]``.
    """

    def split_lines(line):
        # Split a decoded chunk into individual "data:" segments for the case
        # where the server coalesced several SSE events into one read.
        parts = line.split("data:")
        return [f"data:{part.strip()}\n\n" for part in parts if part.strip()]

    # openai response format:
    # b'data:{"id":"","object":"text_completion","created":1725530204,"model":"meta-llama/Meta-Llama-3-8B-Instruct","system_fingerprint":"2.0.1-native","choices":[{"index":0,"delta":{"role":"assistant","content":"?"},"logprobs":null,"finish_reason":null}]}\n\n'
    for raw in gen:
        line = raw.decode("utf-8")
        # The original duplicated ~15 lines across a multi-"data:" branch and a
        # whole-line branch; the only difference was what got parsed, so the
        # two are unified here.  Chunks with at most one "data:" marker are
        # processed whole, preserving the original fallback behavior exactly.
        segments = split_lines(line) if line.count("data:") > 1 else [line]
        for segment in segments:
            start = segment.find("{")
            end = segment.rfind("}") + 1
            json_str = segment[start:end]
            try:
                # sometimes yield empty chunk, do a fallback here
                json_data = json.loads(json_str)
                if (
                    json_data["choices"][0]["finish_reason"] != "eos_token"
                    and "content" in json_data["choices"][0]["delta"]
                ):
                    yield f"data: {repr(json_data['choices'][0]['delta']['content'].encode('utf-8'))}\n\n"
            except Exception:
                # Malformed or empty segment: forward the raw payload so the
                # client still sees it.
                yield f"data: {repr(json_str.encode('utf-8'))}\n\n"
    yield "data: [DONE]\n\n"