diff --git a/qllm/plugin/chatcli/generation.py b/qllm/plugin/chatcli/generation.py index 184659b..b23f163 100644 --- a/qllm/plugin/chatcli/generation.py +++ b/qllm/plugin/chatcli/generation.py @@ -10,7 +10,7 @@ def generate_stream(model, tokenizer, prompt: str, device, max_new_tokens: int, lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0) - past_kvs = transformers.DynamicCache() + past_kvs = None output_ids = list(inputs.input_ids) input_echo_len = len(output_ids) @@ -70,7 +70,7 @@ def generate(model, tokenizer, prompt: str, max_new_tokens:int, context_len: int lhs_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0) - past_kvs = transformers.DynamicCache() + past_kvs = None output_ids = list(inputs.input_ids) input_echo_len = len(output_ids) diff --git a/qllm/plugin/chatcli/inference.py b/qllm/plugin/chatcli/inference.py index 07b264b..8406883 100644 --- a/qllm/plugin/chatcli/inference.py +++ b/qllm/plugin/chatcli/inference.py @@ -1,5 +1,7 @@ import time import torch +import transformers +from packaging.version import Version, InvalidVersion try: import fastchat #from fastchat.conversation import Conversation, SeparatorStyle @@ -28,9 +30,16 @@ def chat_loop( debug: bool = True, echo: bool = False, ): + model_type = str(type(model)).lower() + use_fastchat_v2 = False if _fastchat_available: + try: + use_fastchat_v2 = Version(transformers.__version__) < Version("4.3") and "llama" not in model_type + except InvalidVersion: + use_fastchat_v2 = False + + if use_fastchat_v2: return chat_loop_v2(model, tokenizer) - model_type = str(type(model)).lower() if "llama" not in model_type and hasattr(tokenizer, 'apply_chat_template'): return chat_loop_v3(model, tokenizer) assert "llama" in model_type, 'have you installed fschat? please run `pip install fschat` and try again.'