forked from Zyphra/Zamba2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhf_utils.py
More file actions
17 lines (12 loc) · 730 Bytes
/
hf_utils.py
File metadata and controls
17 lines (12 loc) · 730 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import json
import torch
from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file
def load_config_hf(model_name):
    """Resolve a model's configuration file and return it parsed from JSON.

    The file named by ``CONFIG_NAME`` is located through ``cached_file``,
    called with ``force_download=False`` and
    ``_raise_exceptions_for_missing_entries=True``.

    Args:
        model_name: Model identifier or path, forwarded unchanged to
            ``cached_file``.

    Returns:
        The object produced by ``json.load`` for the resolved file.
    """
    resolved_archive_file = cached_file(
        model_name,
        CONFIG_NAME,
        _raise_exceptions_for_missing_entries=True,
        force_download=False,
    )
    # Use a context manager so the handle is closed deterministically instead
    # of being left for the garbage collector; JSON text is read as UTF-8
    # rather than with the platform's default encoding.
    with open(resolved_archive_file, encoding="utf-8") as config_file:
        return json.load(config_file)
def load_state_dict_hf(
    model_name,
    device=None,
    dtype=None,
    weights_name="Zamba2_2p7b_direct_from_pytorch.pt",
):
    """Resolve a checkpoint file and load it with ``torch.load``.

    Args:
        model_name: Model identifier or path, forwarded unchanged to
            ``cached_file``.
        device: Device the returned tensors should live on. ``None`` leaves
            placement to ``torch.load``.
        dtype: Floating-point dtype the returned tensors should have.
            ``None`` keeps the dtypes stored in the checkpoint.
        weights_name: File name of the checkpoint to resolve. Defaults to the
            name that was previously hard-coded here.

    Returns:
        The object loaded from the checkpoint. When ``device`` and ``dtype``
        are both ``None``, or the loaded object is not a ``dict``, it is
        returned exactly as ``torch.load`` produced it. Otherwise a new dict
        is returned whose tensor values were cast and moved as requested.
    """
    # For a non-fp32 target, deserialize on the CPU first so the cast happens
    # before the tensors are moved to `device`.
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(
        model_name,
        weights_name,
        _raise_exceptions_for_missing_entries=True,
        force_download=False,
    )
    # NOTE(security): torch.load unpickles the file; only use it on
    # checkpoints from a trusted source.
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)

    nothing_requested = dtype is None and device is None
    if nothing_requested or not isinstance(state_dict, dict):
        return state_dict

    # Previously `dtype` only influenced `map_location`, so a non-fp32 request
    # came back uncast and on the CPU. Honor both arguments here.
    converted = {}
    for key, value in state_dict.items():
        if torch.is_tensor(value):
            # Cast first, then move. Integer/bool tensors keep their dtype.
            if dtype is not None and value.is_floating_point():
                value = value.to(dtype=dtype)
            if device is not None:
                value = value.to(device=device)
        converted[key] = value
    return converted