Backend Inference Refactor #35
Suryanshg left a comment
General comments for now; I will do another round of review soon.
| # Use the last message as the user message (it should always be a user message) | ||
| user_query = messages[-1]['content'] | ||
| retrieved_docs: List[AniZenithVectorSearchResult] = self.db_client.perform_vector_search(user_query, limit=VECTOR_SEARCH_LIMIT) | ||
| print(f"Retrieved Docs: ({len(retrieved_docs)})") |
nit:
print(f"Retrieved Docs: ({len(retrieved_docs)})") --> print(f"Retrieved ({len(retrieved_docs)}) relevant docs")
| print(f"Retrieved Docs: ({len(retrieved_docs)})") | ||
| # 2) Rerank results using the reranker based on document info and user query | ||
| with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranker").time(): |
ultranit: use stage="reranking", since it's a verb?
| # 2) Rerank results using the reranker based on document info and user query | ||
| with CHATBOT_PIPELINE_LATENCY_SUMMARY.labels(model=current_model.get_name(), stage="reranker").time(): | ||
| recommended_docs: List[AniZenithVectorSearchResult] = self.reranker.rerank(user_query, retrieved_docs, limit=RERANK_LIMIT) |
rename: recommended_docs --> reranked_docs
| # Add base system prompt | ||
| lines.append(SYSTEM_PROMPT) | ||
| lines.append("Here are the recommendation system's top shows:\n") |
maybe you can just add this directly to the System Prompt, instead of manually doing it here?
I considered that, but we might also want to change how the system prompt is built (for example, by adding more context strings), so I think it is good practice to keep the two separate.
But yes, it should be moved into some config system.
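One possible shape for such a config is sketched below. It keeps the base prompt and the docs header as separate fields so either can change independently. `PromptConfig`, `build_system_prompt`, and the placeholder strings are hypothetical names for illustration, not code from this PR.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(frozen=True)
class PromptConfig:
    """Hypothetical config object holding the prompt pieces discussed above."""
    base_prompt: str
    docs_header: str = "Here are the recommendation system's top shows:\n"
    extra_context: List[str] = field(default_factory=list)


def build_system_prompt(config: PromptConfig, doc_lines: List[str]) -> str:
    """Join base prompt, optional context strings, docs header, and doc lines."""
    lines = [config.base_prompt, *config.extra_context, config.docs_header, *doc_lines]
    return "\n".join(lines)


# Placeholder values only; the real SYSTEM_PROMPT is not shown in this thread.
config = PromptConfig(base_prompt="BASE SYSTEM PROMPT PLACEHOLDER")
prompt = build_system_prompt(config, ["1. Show A", "2. Show B"])
```

With this layout, adding more context strings only touches `extra_context`, and the header wording lives in one place.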
| from fastapi.middleware.cors import CORSMiddleware | ||
| import logging | ||
| from prometheus.prometheus_middleware import PrometheusMiddleware, prometheus_router | ||
| from dotenv import load_dotenv |
Can you revert this? You might have added it for local testing.
The docker compose YAML already injects env variables using a specific frontend.env file.
| def stream(self, messages: List[Dict[str, str]]): | ||
| self._usage_data = None | ||
| print("Starting tokenize") |
Either remove this print statement, or add this to the stream() definition of HFInferenceClientModel
| thread.start() | ||
| # Accumulate usage | ||
| input_token_count = inputs['input_ids'].shape[-1] |
If it's not too much trouble, can you also add a small inline comment describing the shape of inputs['input_ids']? Something like:
# inputs['input_ids'] has shape (x, y, z)
input_token_count = inputs['input_ids'].shape[-1]
This improves the readability of ML-based code a lot (at least for me).
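For reference, a Hugging Face tokenizer called with return_tensors="pt" normally returns input_ids as a 2-D tensor of shape (batch_size, sequence_length), so shape[-1] would be the prompt length in tokens. The sketch below assumes that layout and uses a plain nested list with made-up token ids as a stand-in for the real tensor.

```python
# Stand-in for inputs['input_ids']: one prompt in the batch, four made-up token ids.
input_ids = [[101, 2023, 2003, 102]]

# Assumed layout: (batch_size, sequence_length)
batch_size = len(input_ids)
sequence_length = len(input_ids[0])

# Equivalent of inputs['input_ids'].shape[-1] on a real tensor
input_token_count = sequence_length
```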
| def generate(): | ||
| # Ensure no gradients | ||
| with torch.inference_mode(): |
What's the difference between this and with torch.no_grad():?
I was looking at the docs for recommended ways to speed up inference without using pipeline. This was one recommendation, but yes, I believe it is not actually doing anything differently here.
| self._usage_data = None | ||
| def stream(self, messages: List[Dict[str, str]]): |
I do not think this method is thread safe, as two concurrent requests can arrive at any time and the second request can overwrite self._usage_data (originally set for the first request). The self._thread_error variable can have the same problem.
I am still thinking about how to make it thread safe, but some basic options are:
- Write "pure functions" (if you are not familiar with them, you can look them up or we can discuss)
- Write non-blocking code (which in some sense you are already doing, but using threads in FastAPI should be discouraged, since FastAPI is async and should rely on event-loop-based processing)
- Maybe just use local variables instead of instance attributes and return the usage data within a Tuple or something similar
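The last option in the list above could look roughly like the sketch below. It keeps all counters local to each stream() call and hands usage back as the final yielded item, so nothing is stored on self. `Usage`, `Model`, and the word-splitting "generation" are hypothetical stand-ins, not the PR's classes.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Tuple


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int


class Model:
    """Hypothetical model wrapper with no per-request instance state."""

    def stream(self, messages: List[Dict[str, str]]) -> Iterator[Tuple[str, Optional[Usage]]]:
        # Everything below is local to this call, so concurrent calls
        # cannot overwrite each other's counters.
        input_tokens = sum(len(m["content"].split()) for m in messages)  # toy count
        output_tokens = 0
        for chunk in ("echo:", *messages[-1]["content"].split()):  # toy generation
            output_tokens += 1
            yield chunk, None
        # The final item carries the usage for this request only.
        yield "", Usage(input_tokens, output_tokens)


model = Model()
first = model.stream([{"role": "user", "content": "hello there"}])
second = model.stream([{"role": "user", "content": "one two three"}])
# Interleave the two streams to mimic overlapping requests.
next(first)
next(second)
first_usage = list(first)[-1][1]
second_usage = list(second)[-1][1]
```

Each caller reads its own usage from the last item, regardless of how the two streams interleave.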
Yes, it is not thread safe. There is a TODO in InferenceManager to add a blocking queue system. The idea is that each model runs only one stream() job at a time (or several jobs run at once if we have multiple models loaded in the backend). I looked at some ways to do this, but it is not trivial, so we accept one request at a time for now.
TODO: After discussion, this problem requires significant work and needs an additional PR.
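As a rough illustration of the "one stream() job at a time" idea, a lock around the stream is the simplest form. This is a sketch only: `SerializedModel` and `EchoModel` are hypothetical names, and a threading.Lock blocks the calling thread, so inside async handlers an asyncio-based primitive would likely be a better fit.

```python
import threading
from typing import Dict, Iterator, List


class EchoModel:
    """Toy model that streams the last message back word by word."""

    def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]:
        for word in messages[-1]["content"].split():
            yield word


class SerializedModel:
    """Hypothetical wrapper that lets only one stream() run at a time."""

    def __init__(self, model) -> None:
        self._model = model
        self._lock = threading.Lock()

    def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]:
        # The lock is held until the generator is exhausted or closed,
        # so a second caller blocks until the first stream ends.
        with self._lock:
            yield from self._model.stream(messages)


model = SerializedModel(EchoModel())
reply = list(model.stream([{"role": "user", "content": "a b c"}]))
```

A queue with a single worker would give the same one-at-a-time behavior while also preserving request order.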
Suryanshg left a comment
Added some more comments about concurrency issues identified in the code.
| # Accumulate usage | ||
| input_token_count = inputs['input_ids'].shape[-1] | ||
| output_token_count = 0 | ||
| for text in streamer: |
Is it guaranteed that the streamer always yields individual tokens rather than decoded strings (which can span multiple tokens at once)? It might also yield empty strings, or combine multiple tokens before yielding.
In all these cases, the counted value won't exactly match the number of tokens actually generated...
As discussed, TODO: add a test case to the model tests to verify this.
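The concern can be illustrated without a real model. In the sketch below, `toy_streamer` is a made-up stand-in that buffers decoded pieces and flushes them at word boundaries; it is not the transformers implementation, and the sub-word pieces are invented for the example.

```python
from typing import Iterator, List

# Made-up sub-word pieces standing in for generated tokens.
TOKENS: List[str] = ["To", "ken", "s", " are", " not", " chunk", "s", "."]


def toy_streamer(tokens: List[str]) -> Iterator[str]:
    """Yield the buffered text whenever a piece that starts a new word arrives."""
    buffer = ""
    for piece in tokens:
        if piece.startswith(" ") and buffer:
            yield buffer
            buffer = ""
        buffer += piece
    if buffer:
        yield buffer


chunks = list(toy_streamer(TOKENS))
chunk_count = len(chunks)   # what a `for text in streamer` counter would report
token_count = len(TOKENS)   # what was actually generated
```

Here the chunk counter undercounts, which is the kind of mismatch a test case could check for against the real streamer. Counting generated token ids directly, rather than streamed chunks, would sidestep the question.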
| collected_result = "" | ||
| for result in chat_with_llm(messages=[{"role": "user", "content": TEST_USER_MESSAGE}], | ||
| use_local_model=use_local_model): | ||
| collected_result = result |
Currently this does not call the local model, so the test is invalid. It needs a monkeypatch fix.
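A minimal sketch of what such a fix might check is shown below. It uses the standard library's unittest.mock in place of pytest's monkeypatch fixture, and every class and function here (`LocalModel`, `RemoteModel`, this version of `chat_with_llm`) is a hypothetical stand-in, since the real dispatch code is not shown in this thread.

```python
from typing import Dict, Iterator, List
from unittest import mock


class LocalModel:
    """Hypothetical stand-in for the locally loaded model."""

    def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]:
        raise RuntimeError("would load real weights")


class RemoteModel:
    """Hypothetical stand-in for the hosted inference client."""

    def stream(self, messages: List[Dict[str, str]]) -> Iterator[str]:
        yield "remote reply"


LOCAL_MODEL = LocalModel()
REMOTE_MODEL = RemoteModel()


def chat_with_llm(messages: List[Dict[str, str]], use_local_model: bool = False) -> Iterator[str]:
    """Assumed dispatch: pick a model based on the flag and stream from it."""
    model = LOCAL_MODEL if use_local_model else REMOTE_MODEL
    yield from model.stream(messages)


def test_local_model_is_used() -> None:
    # Replace the expensive local stream() with a fake and record the call.
    fake_stream = mock.Mock(return_value=iter(["local ", "local reply"]))
    with mock.patch.object(LocalModel, "stream", fake_stream):
        results = list(chat_with_llm([{"role": "user", "content": "hi"}],
                                     use_local_model=True))
    # The test only passes if the local path was actually taken.
    fake_stream.assert_called_once()
    assert results[-1] == "local reply"


test_local_model_is_used()
```

The key point is the assert_called_once() check: without an assertion on the patched target, the test can pass while exercising the wrong model.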
Implements a new object-oriented, scalable backend structure to make it easier to implement an agentic framework.