From 0e3fee663d95090af2e4b83e0be961defb2c8b2e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 1 May 2024 15:03:13 -0400 Subject: [PATCH 1/3] adding benchmark prompts --- benchmark_prompts.txt | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 benchmark_prompts.txt diff --git a/benchmark_prompts.txt b/benchmark_prompts.txt new file mode 100644 index 0000000..d433f38 --- /dev/null +++ b/benchmark_prompts.txt @@ -0,0 +1,26 @@ +What is the capital of France? +Describe the process of photosynthesis in detail. +Write a short story about a magical talking cat. +How do airplanes fly? +Explain the concept of time dilation in Einstein's theory of relativity. +What are the main differences between living in a city and living in the countryside? +Write a haiku about the moon. +Discuss the impact of social media on modern society. +What is the chemical formula for water? +Describe your perfect day from start to finish. +How do you make a paper airplane? +Write a dialogue between two characters discussing the meaning of life. +What are the three branches of the United States government and their roles? +Describe the life cycle of a butterfly. +Write a persuasive paragraph on why everyone should learn to play a musical instrument. +What is the difference between a virus and a bacteria? +Create a recipe for a healthy breakfast smoothie. +Discuss the themes present in Shakespeare's play "Hamlet". +What are the five largest countries in the world by population? +Write a letter to your future self in 10 years. +How does the Internet work? +Describe the plot of your favorite movie. +What are the main causes of climate change? +Write a poem about the changing seasons. +Discuss the importance of regular exercise for maintaining good health. +What is the Pythagorean theorem, and how is it used in mathematics? \ No newline at end of file From c76c2eeebba77841d64296f6ed34fe6ed50f6aa6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 1 May 2024 15:08:04 -0400 Subject: [PATCH 2/3] adding basic benchmarking code --- run_benchmark.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 run_benchmark.py diff --git a/run_benchmark.py b/run_benchmark.py new file mode 100644 index 0000000..510e3a3 --- /dev/null +++ b/run_benchmark.py @@ -0,0 +1,112 @@ +import sys + +sys.path.append("./") +import torch +from torch.nn import functional as F +from hqq.core.quantize import BaseQuantizeConfig +from huggingface_hub import snapshot_download +from IPython.display import clear_output +from tqdm.auto import trange +from transformers import AutoConfig, AutoTokenizer +from transformers.utils import logging as hf_logging + +from src.build_model import OffloadConfig, QuantConfig, build_model + +model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" +quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo" +state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo" +benchmark_prompts = "benchmark_prompts.txt" + +def read_prompt(file_path): + with open(file_path, "r") as f: + prompts = f.readlines() + return prompts + +# print(read_prompt(benchmark_prompts)) + + +# config = AutoConfig.from_pretrained(quantized_model_name) + +# device = torch.device("cuda:0") + +# ##### Change this to 5 if you have only 12 GB of GPU VRAM ##### +# offload_per_layer = 4 +# # offload_per_layer = 5 +# ############################################################### + +# num_experts = config.num_local_experts + +# offload_config = OffloadConfig( +# main_size=config.num_hidden_layers * (num_experts - offload_per_layer), +# offload_size=config.num_hidden_layers * offload_per_layer, +# buffer_size=4, +# offload_per_layer=offload_per_layer, +# ) + + +# attn_config = BaseQuantizeConfig( +# nbits=4, +# group_size=64, +# quant_zero=True, +# quant_scale=True, +# ) +# attn_config["scale_quant_params"]["group_size"] = 256 + + +# ffn_config = BaseQuantizeConfig( +# nbits=2, +# group_size=16, +# quant_zero=True, +# quant_scale=True, +# ) +# quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config) + + +# model = build_model( +# device=device, +# quant_config=quant_config, +# offload_config=offload_config, +# state_path=state_path, +# ) + +# from transformers import TextStreamer + + +# tokenizer = AutoTokenizer.from_pretrained(model_name) +# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) +# past_key_values = None +# sequence = None + +# seq_len = 0 +# while True: +# print("User: ", end="") +# user_input = input() +# print("\n") + +# user_entry = dict(role="user", content=user_input) +# input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device) + +# if past_key_values is None: +# attention_mask = torch.ones_like(input_ids) +# else: +# seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1) +# attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device) + +# print("Mixtral: ", end="") +# result = model.generate( +# input_ids=input_ids, +# attention_mask=attention_mask, +# past_key_values=past_key_values, +# streamer=streamer, +# do_sample=True, +# temperature=0.9, +# top_p=0.9, +# max_new_tokens=512, +# pad_token_id=tokenizer.eos_token_id, +# return_dict_in_generate=True, +# output_hidden_states=True, +# ) +# print("\n") + +# sequence = result["sequences"] +# past_key_values = result["past_key_values"] From 9ca87378afbefe02b57bb9826efb12a599e33b9b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 1 May 2024 16:53:17 -0400 Subject: [PATCH 3/3] finishing benchmarking --- run_benchmark.py | 228 +++++++++++++++--------- src/build_model.py | 18 +- src/custom_layers.py | 101 +---------- src/expert_cache.py | 111 ++++++------ track_results/data/Initial_attempt.json | 1 + track_results/logs/Initial_attempt.txt | 42 +++++ 6 files changed, 254 insertions(+), 247 deletions(-) create mode 100644 track_results/data/Initial_attempt.json create mode 100644 track_results/logs/Initial_attempt.txt diff --git a/run_benchmark.py b/run_benchmark.py index 510e3a3..de71d07 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -9,12 +9,13 @@ from tqdm.auto import trange from transformers import AutoConfig, AutoTokenizer from transformers.utils import logging as hf_logging - +import time +from tqdm import tqdm from src.build_model import OffloadConfig, QuantConfig, build_model model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo" -state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo" +state_path = "/home/amangupt/random/mixtral-offloading/Mixtral-8x7B-Instruct-v0.1-offloading-demo" benchmark_prompts = "benchmark_prompts.txt" def read_prompt(file_path): @@ -22,91 +23,144 @@ def read_prompt(file_path): prompts = f.readlines() return prompts +all_prompts = read_prompt(benchmark_prompts) +all_prompts = all_prompts[:2] # print(read_prompt(benchmark_prompts)) -# config = AutoConfig.from_pretrained(quantized_model_name) - -# device = torch.device("cuda:0") - -# ##### Change this to 5 if you have only 12 GB of GPU VRAM ##### -# offload_per_layer = 4 -# # offload_per_layer = 5 -# ############################################################### - -# num_experts = config.num_local_experts - -# offload_config = OffloadConfig( -# main_size=config.num_hidden_layers * (num_experts - offload_per_layer), -# offload_size=config.num_hidden_layers * offload_per_layer, -# buffer_size=4, -# offload_per_layer=offload_per_layer, -# ) - - -# attn_config = BaseQuantizeConfig( -# nbits=4, -# group_size=64, -# quant_zero=True, -# quant_scale=True, -# ) -# attn_config["scale_quant_params"]["group_size"] = 256 - - -# ffn_config = BaseQuantizeConfig( -# nbits=2, -# group_size=16, -# quant_zero=True, -# quant_scale=True, -# ) -# quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config) - - -# model = build_model( -# device=device, -# quant_config=quant_config, -# offload_config=offload_config, -# state_path=state_path, -# ) - -# from transformers import TextStreamer - - -# tokenizer = AutoTokenizer.from_pretrained(model_name) -# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) -# past_key_values = None -# sequence = None - -# seq_len = 0 -# while True: -# print("User: ", end="") -# user_input = input() -# print("\n") - -# user_entry = dict(role="user", content=user_input) -# input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device) - -# if past_key_values is None: -# attention_mask = torch.ones_like(input_ids) -# else: -# seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1) -# attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device) - -# print("Mixtral: ", end="") -# result = model.generate( -# input_ids=input_ids, -# attention_mask=attention_mask, -# past_key_values=past_key_values, -# streamer=streamer, -# do_sample=True, -# temperature=0.9, -# top_p=0.9, -# max_new_tokens=512, -# pad_token_id=tokenizer.eos_token_id, -# return_dict_in_generate=True, -# output_hidden_states=True, -# ) -# print("\n") - -# sequence = result["sequences"] -# past_key_values = result["past_key_values"] +config = AutoConfig.from_pretrained(quantized_model_name) + +device = torch.device("cuda:0") + +##### Change this to 5 if you have only 12 GB of GPU VRAM ##### +offload_per_layer = 4 +# offload_per_layer = 5 +############################################################### + +num_experts = config.num_local_experts + +offload_config = OffloadConfig( + main_size=config.num_hidden_layers * (num_experts - offload_per_layer), + offload_size=config.num_hidden_layers * offload_per_layer, + buffer_size=4, + offload_per_layer=offload_per_layer, +) + + +attn_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + quant_zero=True, + quant_scale=True, +) +attn_config["scale_quant_params"]["group_size"] = 256 + + +ffn_config = BaseQuantizeConfig( + nbits=2, + group_size=16, + quant_zero=True, + quant_scale=True, +) +quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config) + + +model, expert_cache_obj = build_model( + device=device, + quant_config=quant_config, + offload_config=offload_config, + state_path=state_path, +) + +from transformers import TextStreamer + + +tokenizer = AutoTokenizer.from_pretrained(model_name) +streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) +past_key_values = None +sequence = None +total_time = [] +total_num_tokens = [] + +seq_len = 0 +for i in tqdm(range(len(all_prompts))): + start = time.time() + print("User: ", end="") + user_input = all_prompts[i] + print(user_input) + print("\n") + + user_entry = dict(role="user", content=user_input) + input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device) + + if past_key_values is None: + attention_mask = torch.ones_like(input_ids) + else: + seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1) + attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device) + + print("Mixtral: ", end="") + result = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + streamer=streamer, + do_sample=True, + temperature=0.9, + top_p=0.9, + max_new_tokens=100, + pad_token_id=tokenizer.eos_token_id, + return_dict_in_generate=True, + output_hidden_states=True, + ) + print("\n") + sequence = result["sequences"] + past_key_values = result["past_key_values"] + end = time.time() + total_time.append(end - start) + seq_len = sum([len(seq) for seq in sequence]) + total_num_tokens.append(seq_len) + +# CHANGE FILENAME HERE +filename = "Initial_attempt" + +# make a track_results/logs/{filename}.txt file + +log_file = open(f"track_results/logs/{filename}.txt", "w") +dump_data_file = open(f"track_results/data/{filename}.json", "w") + + +print("TIME BENCHMARKS", file=log_file) +print(f"Total time taken: {sum(total_time)} seconds", file=log_file) +print(f"Total number of tokens generated: {sum(total_num_tokens)}", file=log_file) +print(f"Average token per second: {sum(total_num_tokens)/sum(total_time)}", file=log_file) +print('\n\n\n', file=log_file) + +print("HIT RATE BENCHMARKS", file=log_file) +data_hits = {} + +for k in expert_cache_obj.group_infos: + data_hits[k] = expert_cache_obj.group_infos[k].expert_counts +# print(data_hits) +# print overall hit rate and hit rate per layer +overall_hits = 0 +overall_misses = 0 +for layer in data_hits: + tot_calls = 0 + tot_hits = 0 + # print(data_hits[layer]) + for exp in data_hits[layer]: + tot_calls += data_hits[layer][exp][0] + tot_hits += data_hits[layer][exp][1] + # print(tot_hits, tot_calls) + overall_hits += tot_hits + overall_misses += tot_calls - tot_hits + print(f"Layer {layer}: Hit rate = {tot_hits/tot_calls}", file=log_file) + +print(f"Overall hit rate = {overall_hits/(overall_hits + overall_misses)}", file=log_file) + + +# dump data_hits, total_time, total_num_tokens to a json file +import json +all_stats = {"data_hits": data_hits, "total_time": total_time, "total_num_tokens": total_num_tokens} +json.dump(all_stats, dump_data_file) \ No newline at end of file diff --git a/src/build_model.py b/src/build_model.py index 4306dc5..652f3e5 100644 --- a/src/build_model.py +++ b/src/build_model.py @@ -121,7 +121,7 @@ def get_default_ffn_quant_config(ffn_dim: int = 14336, hidden_dim: int = 4096): def make_empty_expert( - model_config: MixtralConfig, quant_config: QuantConfig, use_gpu: bool = True + model_config: MixtralConfig, quant_config: QuantConfig ) -> MixtralBLockSparseTop2MLP_HQQ: meta1, meta2 = quant_config.get_ffn_metas( model_config.hidden_size, model_config.intermediate_size @@ -131,7 +131,6 @@ def make_empty_expert( quant_config.ffn_config, meta1, meta2, - use_gpu ) @@ -141,7 +140,6 @@ def make_and_load_expert_wrapper( states_dir: str, expert_uid: tuple[int, int], device: torch.device, - use_gpu: bool = True, ) -> MixtralExpertWrapper: layer_idx, expert_idx = expert_uid @@ -151,7 +149,7 @@ def make_and_load_expert_wrapper( state_fpath = json.load(f)["weight_map"][f"{module_idx}.w1.W_q"] state_dict = load_file(os.path.join(states_dir, state_fpath), device=str(device)) - expert = make_empty_expert(config, quant_config, use_gpu) + expert = make_empty_expert(config, quant_config) expert.load_state_dict(state_dict, strict=True) return MixtralExpertWrapper(expert, device) @@ -175,11 +173,11 @@ def build_model( state_dict_00 = load_00_expert_state_dict(state_path, device) - def _make_module(use_gpu=True): + def _make_module(): config = AutoConfig.from_pretrained(model_name) - expert = make_empty_expert(config, quant_config, use_gpu) + expert = make_empty_expert(config, quant_config) expert.load_state_dict(state_dict_00) - return MixtralExpertWrapper(expert, device=device if use_gpu else 'cpu') + return MixtralExpertWrapper(expert, device=device) with device, with_default_dtype(torch.float16): model = MixtralForCausalLM( @@ -208,7 +206,6 @@ def _make_module(use_gpu=True): main_size=offload_config.main_size, offload_size=offload_config.offload_size, buffer_size=offload_config.buffer_size, - fiddler=True, ) for layer_idx in trange(model_config.num_hidden_layers, desc="Loading experts"): curr_layer = model.model.layers[layer_idx] @@ -227,8 +224,7 @@ def _make_module(use_gpu=True): quant_config=quant_config, states_dir=state_path, expert_uid=(layer_idx, expert_idx), - device='cpu' if do_offload else device, - use_gpu=not do_offload, + device=device, ) expert_cache.add_expert( @@ -242,4 +238,4 @@ def _make_module(use_gpu=True): torch.cuda.synchronize(device) torch.cuda.empty_cache() - return model + return model, expert_cache diff --git a/src/custom_layers.py b/src/custom_layers.py index f8b2724..cfc3884 100644 --- a/src/custom_layers.py +++ b/src/custom_layers.py @@ -14,7 +14,7 @@ class HQQLinearTritonSavable(HQQLinear): - def __init__(self, layer, quant_config, meta=None, use_gpu=True, **kwargs): + def __init__(self, layer, quant_config, meta=None, **kwargs): """ Example how to get meta: >>>> meta1 = HQQLinearSavable.get_hqq_meta((hidden_dim, ffn_dim), quant_config) @@ -23,7 +23,6 @@ def __init__(self, layer, quant_config, meta=None, use_gpu=True, **kwargs): assert quant_config['weight_quant_params']['nbits'] in [2, 3, 4] - self.use_gpu = use_gpu super().__init__(layer, quant_config, **kwargs) if not hasattr(self, 'meta'): @@ -33,54 +32,11 @@ def __init__(self, layer, quant_config, meta=None, use_gpu=True, **kwargs): self._register_state_dict_hook(self._add_to_state_dict_hook) self._register_load_state_dict_pre_hook(self._load_from_state_dict_hook) - def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params): - quant_scale = scale_quant_params is not None - quant_zero = zero_quant_params is not None - - #Quantize - W_q , meta = Quantizer.quantize(W, **weight_quant_params) - meta.update({'quant_scale':quant_scale, 'quant_zero':quant_zero}) - if(meta['quant_scale']): - meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale'] - if(meta['quant_zero']): - meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], **zero_quant_params); del meta['zero'] - - self.W_q = W_q - self.meta = meta - if self.use_gpu: - self.cuda() - else: - self.cpu() - self.ready = True + def quantize(self, *args, **kwargs): + super().quantize(*args, **kwargs) # repacking self.repack() - - def cuda(self, device_n=0): - if(self.in_gpu): return - self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta, device_n) - if(self.meta['quant_scale']): - self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale'], device_n) - if(self.meta['quant_zero']): - self.meta['zero_q'] , self.meta['meta_zero'] = Quantizer.cuda(self.meta['zero_q'], self.meta['meta_zero'], device_n) - - if(self.bias is not None): - self.bias = self.bias.half().cuda(device_n) - - self.in_gpu = True - - def cpu(self): - if(self.in_gpu==False): return - self.W_q, self.meta = Quantizer.cpu(self.W_q, self.meta) - if(self.meta['quant_scale']): - self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cpu(self.meta['scale_q'], self.meta['meta_scale']) - if(self.meta['quant_zero']): - self.meta['zero_q'] , self.meta['meta_zero'] = Quantizer.cpu(self.meta['zero_q'], self.meta['meta_zero']) - - if(self.bias is not None): - self.bias = self.bias.half().cpu() - - self.in_gpu = False def repack(self): if self.W_q.shape != self.meta['shape']: @@ -91,46 +47,11 @@ def repack(self): self.W_q = Quantizer.pack[self.meta['packing']](W_q) def forward(self, x): - if not self.use_gpu: - return self.forward_pytorch(x) return self.forward_triton(x) def set_backend(self, backend): pass - @torch.inference_mode() - def forward_pytorch(self, x): - assert self.ready, "model was not quantized" - assert self.meta['axis'] == 0 - - W_q, meta = self.W_q, self.meta - del_keys = [] - if(meta['quant_scale']): - meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') - if(meta['quant_zero']): - meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') - - ### FORWARD CPU ############ - W_q_p = Quantizer.unpack[meta['packing']](W_q).half() - W_q_p = W_q_p[:meta['shape'][0], ...] - W_q_p = W_q_p.reshape((meta['group_size'], -1)) - - if((meta['group_size'] is not None) and (meta['nbits']==3)): - W_q_p = W_q_p[:meta['group_size']] if (meta['axis']==0) else W_q_p[:,:meta['group_size']] - W_est = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape']) - - out = torch.matmul(x, W_est.t()) - if(self.bias!=None): out += self.bias - ############################ - - #Cleanup - for key in del_keys: - del meta[key] - del W_q_p - del W_est - - return out - @torch.inference_mode() def forward_triton(self, x): assert self.ready, "model was not quantized" @@ -144,7 +65,6 @@ def forward_triton(self, x): if 'quant_zero' in meta and meta['quant_zero']: meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') - ### FORWARD GPU ############ K = meta['shape'][1] N = meta['shape'][0] @@ -164,7 +84,6 @@ def forward_triton(self, x): meta['zero'].view(-1, K), bias=self.bias if hasattr(self, 'bias') else None, ) - ############################ #Cleanup for key in del_keys: @@ -322,12 +241,12 @@ def load_state_dict(self, *args, **kwargs): class MixtralBLockSparseTop2MLP_HQQ(nn.Module): - def __init__(self, config: MixtralConfig, quant_config: Dict[str, Any], meta1, meta2, use_gpu=True): + def __init__(self, config: MixtralConfig, quant_config: Dict[str, Any], meta1, meta2): super().__init__() - self.w1 = HQQLinearTritonSavable(None, quant_config, meta1, use_gpu) - self.w2 = HQQLinearTritonSavable(None, quant_config, meta2, use_gpu) - self.w3 = HQQLinearTritonSavable(None, quant_config, meta1, use_gpu) + self.w1 = HQQLinearTritonSavable(None, quant_config, meta1) + self.w2 = HQQLinearTritonSavable(None, quant_config, meta2) + self.w3 = HQQLinearTritonSavable(None, quant_config, meta1) self.act_fn = ACT2FN[config.hidden_act] @@ -386,12 +305,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - if self.experts.fiddler: - current_state = current_state.to(expert_layer.storage.device) - current_hidden_states = expert_layer(current_state).to(routing_weights.device) * routing_weights[top_x_list, idx_list, None] + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + return final_hidden_states, router_logits \ No newline at end of file diff --git a/src/expert_cache.py b/src/expert_cache.py index 56deee6..c4e400e 100644 --- a/src/expert_cache.py +++ b/src/expert_cache.py @@ -24,6 +24,14 @@ class EvictionGroupInfo: offloaded_infos: OrderedDict[ExpertUID, ExpertInfo] = field(default_factory=OrderedDict) hits: int = field(default=0) misses: int = field(default=0) + expert_counts = {} # [total, hits] + num_experts = 8 + expert_counts: Dict[int, List[int]] = field(default_factory=dict) + + def __post_init__(self): + for i in range(self.num_experts): + self.expert_counts[i] = [0, 0] + def add(self, info: ExpertInfo): infos_odict = self.offloaded_infos if info.offloaded else self.main_infos @@ -42,9 +50,11 @@ def swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo): self.offloaded_infos[info_to_evict.uid] = self.main_infos.pop(info_to_evict.uid) def mark_used(self, info: ExpertInfo): + self.expert_counts[info.uid[1]][0] += 1 if info.uid in self.main_infos: self.main_infos.move_to_end(info.uid, last=True) self.hits += 1 + self.expert_counts[info.uid[1]][1] += 1 elif info.uid in self.offloaded_infos: self.offloaded_infos.move_to_end(info.uid, last=True) self.misses += 1 @@ -53,29 +63,25 @@ def mark_used(self, info: ExpertInfo): class ExpertCache: - def __init__(self, make_module: callable, main_size: int, offload_size: int, buffer_size: int, fiddler: bool = False): + def __init__(self, make_module: callable, main_size: int, offload_size: int, buffer_size: int): """Dynamically loads an array of modules with identical hyperparameters""" self.module_type = self.module_size = self.device = None self.active = False - self.fiddler = fiddler self.registered_experts: Dict[ExpertUID, ExpertInfo] = dict() - self.main_modules = [self._check_module(make_module()) for _ in range(main_size)] + self.main_modules = [self._check_module(make_module()) for i in range(main_size)] self.main_infos: List[Optional[ExpertInfo]] = [None for _ in range(main_size)] assert self.module_size is not None + self.offloaded_storages = [ + torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)] self.offloaded_infos: List[Optional[ExpertInfo]] = [None for _ in range(offload_size)] - if self.fiddler: - self.offloaded_storages = [self._check_module(make_module(use_gpu=False)) for _ in range(offload_size)] - else: - self.offloaded_storages = [ - torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(offload_size)] - # temporary storage to shave off latency - self.device_expert_buffers = deque([self._check_module(make_module()) for _ in range(buffer_size)]) - self.offloaded_storage_buffers = deque([ - torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(buffer_size)]) + # temporary storage to shave off latency + self.device_expert_buffers = deque([self._check_module(make_module()) for _ in range(buffer_size)]) + self.offloaded_storage_buffers = deque([ + torch.UntypedStorage(self.module_size).pin_memory(self.device) for _ in range(buffer_size)]) self.group_infos: Dict[int, EvictionGroupInfo] = defaultdict(EvictionGroupInfo) def _check_module(self, module: MixtralExpertWrapper): @@ -87,7 +93,7 @@ def _check_module(self, module: MixtralExpertWrapper): else: assert isinstance(module, self.module_type) assert len(module.storage) == self.module_size - assert module.storage.device == self.device or self.fiddler + assert module.storage.device == self.device return module def add_expert(self, uid: ExpertUID, module: MixtralExpertWrapper, eviction_group: int = 0, @@ -114,17 +120,15 @@ def add_expert_storage(self, uid: ExpertUID, storage: torch.UntypedStorage, if offload is None or offload: # True or None for i in range(len(self.offloaded_storages)): if self.offloaded_infos[i] is None: - if self.fiddler: - self.offloaded_storages[i].storage.copy_(storage) - else: - self.offloaded_storages[i].copy_(storage) + self.offloaded_storages[i].copy_(storage) info = ExpertInfo(uid, eviction_group=eviction_group, offloaded=True, index=i) self.registered_experts[uid] = self.offloaded_infos[i] = info self.group_infos[eviction_group].add(info) return # done allocating; found an offloaded spot raise ValueError("Cache is full") - def load_experts(self, *uids: ExpertUID, unordered: bool = False) -> Iterator[Tuple[ExpertUID, MixtralExpertWrapper]]: + def load_experts( + self, *uids: ExpertUID, unordered: bool = False) -> Iterator[Tuple[ExpertUID, MixtralExpertWrapper]]: """ :example: >>> for uid, expert in expert_cache.load_experts(*list_of_uids, unordered=True): @@ -150,44 +154,37 @@ def load_experts(self, *uids: ExpertUID, unordered: bool = False) -> Iterator[Tu try: self.active = True - - if self.fiddler: - selected_experts = [self.main_modules[info.index] if not info.offloaded else self.offloaded_storages[info.index] for info in infos] - for uid, expert in zip(uids, selected_experts): - yield uid, expert - - else: - # save pre-loaded experts before they can be swapped - pre_loaded_infos = deque([info for info in infos if not info.offloaded]) - pre_loaded_experts = deque([self.main_modules[info.index] for info in pre_loaded_infos]) - - # begin loading experts into free buffers in background (via non-blocking copy) - infos_to_load = deque([info for info in infos if info.offloaded]) - infos_in_loading = deque([]) - experts_in_loading = deque([]) - window_size = min(len(self.device_expert_buffers) - 1, - len(eviction_group.main_infos), - len(infos_to_load)) - for _ in range(window_size): - info_to_load = infos_to_load.popleft() - infos_in_loading.append(info_to_load) - experts_in_loading.append( - self._swap(info_to_load, eviction_group.choose_expert_to_evict())) - - for info in infos: - if len(pre_loaded_infos) > 0 and info is pre_loaded_infos[0]: - pre_loaded_infos.popleft() - yield (info.uid, pre_loaded_experts.popleft()) - elif len(infos_in_loading) > 0 and info is infos_in_loading[0]: - infos_in_loading.popleft() - yield (info.uid, experts_in_loading.popleft()) - if len(infos_to_load) > 0: - info_to_load = infos_to_load.popleft() - infos_in_loading.append(info_to_load) - experts_in_loading.append( - self._swap(info_to_load, eviction_group.choose_expert_to_evict())) - else: - raise RuntimeError("internal error: caching algorithm failed") + # save pre-loaded experts before they can be swapped + pre_loaded_infos = deque([info for info in infos if not info.offloaded]) + pre_loaded_experts = deque([self.main_modules[info.index] for info in pre_loaded_infos]) + + # begin loading experts into free buffers in background (via non-blocking copy) + infos_to_load = deque([info for info in infos if info.offloaded]) + infos_in_loading = deque([]) + experts_in_loading = deque([]) + window_size = min(len(self.device_expert_buffers) - 1, + len(eviction_group.main_infos), + len(infos_to_load)) + for _ in range(window_size): + info_to_load = infos_to_load.popleft() + infos_in_loading.append(info_to_load) + experts_in_loading.append( + self._swap(info_to_load, eviction_group.choose_expert_to_evict())) + + for info in infos: + if len(pre_loaded_infos) > 0 and info is pre_loaded_infos[0]: + pre_loaded_infos.popleft() + yield (info.uid, pre_loaded_experts.popleft()) + elif len(infos_in_loading) > 0 and info is infos_in_loading[0]: + infos_in_loading.popleft() + yield (info.uid, experts_in_loading.popleft()) + if len(infos_to_load) > 0: + info_to_load = infos_to_load.popleft() + infos_in_loading.append(info_to_load) + experts_in_loading.append( + self._swap(info_to_load, eviction_group.choose_expert_to_evict())) + else: + raise RuntimeError("internal error: caching algorithm failed") finally: self.active = False @@ -211,4 +208,4 @@ def _swap(self, info_to_load: ExpertInfo, info_to_evict: ExpertInfo) -> nn.Modul info_to_evict.offloaded, info_to_load.offloaded = info_to_load.offloaded, info_to_evict.offloaded info_to_evict.index, info_to_load.index = info_to_load.index, info_to_evict.index self.group_infos[info_to_load.eviction_group].swap(info_to_load, info_to_evict) - return device_expert_buffer + return device_expert_buffer \ No newline at end of file diff --git a/track_results/data/Initial_attempt.json b/track_results/data/Initial_attempt.json new file mode 100644 index 0000000..5018653 --- /dev/null +++ b/track_results/data/Initial_attempt.json @@ -0,0 +1 @@ +{"data_hits": {"0": {"0": [46, 17], "1": [45, 21], "2": [56, 26], "3": [49, 25], "4": [54, 33], "5": [49, 24], "6": [49, 25], "7": [63, 41]}, "1": {"0": [56, 31], "1": [42, 24], "2": [32, 12], "3": [81, 62], "4": [56, 40], "5": [54, 30], "6": [49, 22], "7": [42, 24]}, "2": {"0": [43, 19], "1": [61, 36], "2": [53, 26], "3": [72, 51], "4": [41, 22], "5": [42, 21], "6": [54, 31], "7": [46, 22]}, "3": {"0": [44, 24], "1": [44, 24], "2": [52, 30], "3": [49, 29], "4": [58, 44], "5": [58, 38], "6": [60, 38], "7": [46, 24]}, "4": {"0": [45, 22], "1": [51, 26], "2": [54, 30], "3": [40, 15], "4": [74, 49], "5": [53, 25], "6": [45, 24], "7": [50, 26]}, "5": {"0": [77, 51], "1": [61, 37], "2": [38, 16], "3": [38, 13], "4": [42, 17], "5": [63, 44], "6": [48, 26], "7": [45, 26]}, "6": {"0": [70, 44], "1": [71, 50], "2": [36, 17], "3": [23, 9], "4": [44, 24], "5": [43, 24], "6": [59, 40], "7": [66, 43]}, "7": {"0": [57, 37], "1": [57, 35], "2": [47, 30], "3": [53, 34], "4": [47, 30], "5": [60, 41], "6": [54, 35], "7": [36, 21]}, "8": {"0": [65, 42], "1": [55, 37], "2": [22, 8], "3": [10, 3], "4": [113, 101], "5": [45, 23], "6": [75, 57], "7": [27, 13]}, "9": {"0": [61, 45], "1": [47, 31], "2": [26, 11], "3": [32, 19], "4": [92, 75], "5": [50, 29], "6": [40, 24], "7": [63, 44]}, "10": {"0": [67, 44], "1": [68, 43], "2": [66, 43], "3": [41, 17], "4": [49, 23], "5": [50, 35], "6": [18, 3], "7": [52, 31]}, "11": {"0": [89, 67], "1": [50, 26], "2": [56, 35], "3": [23, 14], "4": [30, 10], "5": [52, 29], "6": [44, 28], "7": [66, 42]}, "12": {"0": [51, 33], "1": [54, 34], "2": [54, 34], "3": [29, 13], "4": [51, 26], "5": [39, 26], "6": [66, 51], "7": [66, 48]}, "13": {"0": [52, 29], "1": [61, 40], "2": [61, 40], "3": [42, 22], "4": [55, 29], "5": [43, 26], "6": [67, 51], "7": [31, 14]}, "14": {"0": [55, 27], "1": [44, 21], "2": [35, 13], "3": [48, 24], "4": [67, 47], "5": [59, 37], "6": [57, 34], "7": [47, 22]}, "15": {"0": [60, 35], "1": [54, 33], "2": [50, 26], "3": [53, 31], "4": [47, 25], "5": [45, 20], "6": [44, 20], "7": [59, 34]}, "16": {"0": [58, 35], "1": [61, 38], "2": [40, 22], "3": [52, 32], "4": [22, 8], "5": [49, 24], "6": [40, 19], "7": [89, 69]}, "17": {"0": [60, 29], "1": [46, 22], "2": [57, 33], "3": [33, 19], "4": [40, 20], "5": [64, 39], "6": [68, 48], "7": [44, 22]}, "18": {"0": [38, 9], "1": [65, 41], "2": [61, 36], "3": [33, 13], "4": [33, 12], "5": [88, 65], "6": [48, 28], "7": [46, 21]}, "19": {"0": [46, 25], "1": [45, 26], "2": [78, 57], "3": [61, 37], "4": [53, 30], "5": [41, 22], "6": [52, 32], "7": [35, 16]}, "20": {"0": [26, 10], "1": [65, 46], "2": [84, 68], "3": [40, 18], "4": [24, 12], "5": [44, 25], "6": [74, 54], "7": [55, 34]}, "21": {"0": [51, 35], "1": [24, 8], "2": [54, 37], "3": [41, 26], "4": [67, 47], "5": [36, 22], "6": [80, 65], "7": [58, 42]}, "22": {"0": [47, 24], "1": [50, 33], "2": [74, 57], "3": [77, 59], "4": [54, 36], "5": [43, 23], "6": [39, 22], "7": [27, 12]}, "23": {"0": [29, 15], "1": [56, 28], "2": [28, 13], "3": [70, 50], "4": [85, 63], "5": [55, 36], "6": [54, 35], "7": [35, 18]}, "24": {"0": [61, 50], "1": [49, 34], "2": [44, 25], "3": [38, 23], "4": [40, 25], "5": [49, 33], "6": [66, 42], "7": [65, 50]}, "25": {"0": [54, 38], "1": [56, 35], "2": [70, 49], "3": [75, 57], "4": [39, 20], "5": [37, 20], "6": [61, 40], "7": [19, 7]}, "26": {"0": [69, 53], "1": [66, 43], "2": [58, 42], "3": [49, 26], "4": [54, 32], "5": [29, 18], "6": [30, 14], "7": [57, 37]}, "27": {"0": [61, 35], "1": [51, 29], "2": [42, 21], "3": [36, 14], "4": [60, 33], "5": [54, 32], "6": [58, 36], "7": [50, 29]}, "28": {"0": [38, 15], "1": [54, 30], "2": [66, 38], "3": [44, 24], "4": [56, 28], "5": [74, 49], "6": [41, 21], "7": [39, 21]}, "29": {"0": [50, 19], "1": [56, 27], "2": [54, 27], "3": [58, 32], "4": [55, 27], "5": [33, 10], "6": [55, 28], "7": [50, 29]}, "30": {"0": [33, 11], "1": [45, 17], "2": [55, 27], "3": [64, 36], "4": [64, 38], "5": [56, 32], "6": [46, 15], "7": [49, 31]}, "31": {"0": [55, 19], "1": [57, 31], "2": [33, 13], "3": [35, 12], "4": [63, 33], "5": [79, 48], "6": [48, 29], "7": [42, 18]}}, "total_time": [42.383941888809204, 44.428449630737305], "total_num_tokens": [116, 120]} \ No newline at end of file diff --git a/track_results/logs/Initial_attempt.txt b/track_results/logs/Initial_attempt.txt new file mode 100644 index 0000000..81e5818 --- /dev/null +++ b/track_results/logs/Initial_attempt.txt @@ -0,0 +1,42 @@ +TIME BENCHMARKS +Total time taken: 86.81239151954651 seconds +Total number of tokens generated: 236 +Average token per second: 2.7185059168294274 + + + + +HIT RATE BENCHMARKS +Layer 0: Hit rate = 0.5158150851581509 +Layer 1: Hit rate = 0.5946601941747572 +Layer 2: Hit rate = 0.5533980582524272 +Layer 3: Hit rate = 0.610705596107056 +Layer 4: Hit rate = 0.5266990291262136 +Layer 5: Hit rate = 0.558252427184466 +Layer 6: Hit rate = 0.6092233009708737 +Layer 7: Hit rate = 0.6399026763990268 +Layer 8: Hit rate = 0.6893203883495146 +Layer 9: Hit rate = 0.6763990267639902 +Layer 10: Hit rate = 0.5815085158150851 +Layer 11: Hit rate = 0.6121951219512195 +Layer 12: Hit rate = 0.6463414634146342 +Layer 13: Hit rate = 0.6092233009708737 +Layer 14: Hit rate = 0.5461165048543689 +Layer 15: Hit rate = 0.5436893203883495 +Layer 16: Hit rate = 0.6009732360097324 +Layer 17: Hit rate = 0.5631067961165048 +Layer 18: Hit rate = 0.5461165048543689 +Layer 19: Hit rate = 0.5961070559610706 +Layer 20: Hit rate = 0.6480582524271845 +Layer 21: Hit rate = 0.6861313868613139 +Layer 22: Hit rate = 0.6472019464720195 +Layer 23: Hit rate = 0.6262135922330098 +Layer 24: Hit rate = 0.6844660194174758 +Layer 25: Hit rate = 0.6472019464720195 +Layer 26: Hit rate = 0.6432038834951457 +Layer 27: Hit rate = 0.5558252427184466 +Layer 28: Hit rate = 0.5485436893203883 +Layer 29: Hit rate = 0.48418491484184917 +Layer 30: Hit rate = 0.5024271844660194 +Layer 31: Hit rate = 0.49271844660194175 +Overall hit rate = 0.5932872655478776