diff --git a/qllm/modeling/base.py b/qllm/modeling/base.py index 0e1ef49..5af1386 100644 --- a/qllm/modeling/base.py +++ b/qllm/modeling/base.py @@ -128,7 +128,12 @@ def cached_file_func_in_thread(task_func, *args, **kwargs): return executor.submit(task_func, *args, **kwargs) transformers.utils.hub.cached_file = functools.partial(cached_file_func_in_thread, transformers.utils.hub.cached_file) result = task_func_shard(*args, **kwargs) - result_0 = [future.result() for future in result[0]] + result_0 = [] + for item in result[0]: + if isinstance(item, str): + result_0.append(item) + else: + result_0.append(item.result()) return result_0, result[1] @@ -229,7 +234,7 @@ def from_pretrained( ) llm = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, - torch_dtype=torch_dtype, + dtype=torch_dtype, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation, # device_map="auto",