From 7b6325ccdf6a3acc85b12ceadbd16209dfb30e68 Mon Sep 17 00:00:00 2001 From: Fabien Dupont Date: Thu, 10 Oct 2024 14:38:04 -0400 Subject: [PATCH] Update to Python 3.12 Now that PyTorch 2.4 supports Python 3.12, this change updates the default Python version to 3.12. It also updates PyTorch to 2.4.1, along with all of its dependencies. Signed-off-by: Fabien Dupont --- .github/workflows/lint.yml | 6 +- .../dolomite/hf_models/__init__.py | 2 +- src/instructlab/dolomite/hf_models/config.py | 13 +- src/instructlab/dolomite/hf_models/enums.py | 1 + .../dolomite/hf_models/mixins/__init__.py | 7 +- .../hf_models/mixins/dense/__init__.py | 1 + .../dolomite/hf_models/mixins/dense/base.py | 197 ++++++++++---- .../dolomite/hf_models/mixins/dense/main.py | 53 ++-- .../hf_models/mixins/dense_TP/__init__.py | 1 + .../hf_models/mixins/dense_TP/base.py | 25 +- .../hf_models/mixins/dense_TP/main.py | 94 +++++-- .../dolomite/hf_models/mixins/moe/__init__.py | 7 +- .../dolomite/hf_models/mixins/moe/base.py | 27 +- .../dolomite/hf_models/mixins/moe/main.py | 34 ++- .../hf_models/mixins/moe_TP/__init__.py | 1 + .../dolomite/hf_models/mixins/moe_TP/base.py | 10 +- .../dolomite/hf_models/mixins/moe_TP/main.py | 46 ++-- .../hf_models/model_conversion/__init__.py | 15 +- .../hf_models/model_conversion/bigcode.py | 33 ++- .../hf_models/model_conversion/granite.py | 49 +++- .../hf_models/model_conversion/granitemoe.py | 178 +++++++++---- .../hf_models/model_conversion/llama.py | 248 +++++++++++++----- .../hf_models/modeling_utils/__init__.py | 1 + .../modeling_utils/activations/__init__.py | 2 + .../modeling_utils/activations/base.py | 4 +- .../modeling_utils/activations/glu.py | 3 +- .../modeling_utils/attention/__init__.py | 8 +- .../modeling_utils/attention/base.py | 74 ++++-- .../modeling_utils/attention/flash.py | 4 +- .../modeling_utils/attention/padding_free.py | 17 +- .../modeling_utils/attention/sdpa.py | 8 +- .../modeling_utils/attention/utils.py | 18 +- .../hf_models/modeling_utils/embedding.py | 1 + .../hf_models/modeling_utils/linear.py | 1 + .../modeling_utils/normalization/__init__.py | 11 +- .../normalization/layernorm/__init__.py | 11 +- .../normalization/layernorm/apex.py | 13 +- .../layernorm/apex_persistent.py | 9 +- .../normalization/rmsnorm/__init__.py | 18 +- .../normalization/rmsnorm/apex.py | 19 +- .../normalization/rmsnorm/base.py | 1 + .../normalization/rmsnorm/torchtitan.py | 33 ++- .../position_embedding/__init__.py | 1 + .../position_embedding/alibi.py | 30 ++- .../modeling_utils/position_embedding/rope.py | 66 +++-- .../dolomite/hf_models/models/__init__.py | 2 +- .../hf_models/models/gpt_dolomite/__init__.py | 1 + .../hf_models/models/gpt_dolomite/base.py | 1 + .../hf_models/models/gpt_dolomite/config.py | 1 + .../hf_models/models/gpt_dolomite/layer.py | 10 +- .../hf_models/models/gpt_dolomite/main.py | 1 + .../hf_models/models/gpt_dolomite/mlp.py | 19 +- .../dolomite/hf_models/register_hf.py | 30 ++- src/instructlab/dolomite/hf_models/utils.py | 18 +- src/instructlab/dolomite/utils/hf_hub.py | 10 +- src/instructlab/dolomite/utils/safetensors.py | 9 +- tox.ini | 2 +- 57 files changed, 1107 insertions(+), 398 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 82b5fef..2abe06b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,7 +27,7 @@ on: - '.github/**' env: - PYTHON_VERSION: 3.11 + PYTHON_VERSION: 3.12 jobs: lint: @@ -45,10 +45,10 @@ jobs: fetch-depth: 0 submodules: true - - name: Setup Python 3.11 + - name: 
Setup Python 3.12 uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: - python-version: 3.11 + python-version: 3.12 cache: pip cache-dependency-path: | **/pyproject.toml diff --git a/src/instructlab/dolomite/hf_models/__init__.py b/src/instructlab/dolomite/hf_models/__init__.py index 86cc443..ef3ffb8 100644 --- a/src/instructlab/dolomite/hf_models/__init__.py +++ b/src/instructlab/dolomite/hf_models/__init__.py @@ -2,9 +2,9 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .models.gpt_dolomite.config import GPTDolomiteConfig from .model_conversion import export_to_huggingface, import_from_huggingface from .models import GPTDolomiteForCausalLM, GPTDolomiteModel +from .models.gpt_dolomite.config import GPTDolomiteConfig from .register_hf import register_model_classes register_model_classes() diff --git a/src/instructlab/dolomite/hf_models/config.py b/src/instructlab/dolomite/hf_models/config.py index 538dc34..49a2e9b 100644 --- a/src/instructlab/dolomite/hf_models/config.py +++ b/src/instructlab/dolomite/hf_models/config.py @@ -1,5 +1,7 @@ +# Third Party from transformers import PretrainedConfig +# Local from .enums import AttentionHeadType, InitMethod, PositionEmbeddingType @@ -98,7 +100,9 @@ def __init__( if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values" + assert ( + self.num_key_value_heads == 1 + ), "MultiQueryAttention should have 1 head for keys and values" elif attention_head_type == AttentionHeadType.gqa: assert ( self.num_key_value_heads is not None @@ -108,4 +112,9 @@ def __init__( self.n_head % self.num_key_value_heads == 0 ), "GroupedQueryAttention should have more than 1 head for keys and values" - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) diff --git a/src/instructlab/dolomite/hf_models/enums.py b/src/instructlab/dolomite/hf_models/enums.py index 5055bcf..0ac10fa 100644 --- a/src/instructlab/dolomite/hf_models/enums.py +++ b/src/instructlab/dolomite/hf_models/enums.py @@ -1,3 +1,4 @@ +# Standard from enum import Enum diff --git a/src/instructlab/dolomite/hf_models/mixins/__init__.py b/src/instructlab/dolomite/hf_models/mixins/__init__.py index c4f9102..e899ea2 100644 --- a/src/instructlab/dolomite/hf_models/mixins/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/__init__.py @@ -1,4 +1,7 @@ +# Local from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin -#from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP + +# from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin -#from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP + +# from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py index 0ee5d10..b29b99f 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py @@ -1,2 +1,3 @@ +# 
Local from .base import BaseModelMixin, PreTrainedModelMixin from .main import CausalLMModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/base.py b/src/instructlab/dolomite/hf_models/mixins/dense/base.py index 3298682..e133727 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/base.py @@ -1,14 +1,23 @@ +# Standard import warnings -import torch -import torch.nn as nn +# Third Party from transformers import DynamicCache, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast +import torch +import torch.nn as nn +# Local from ...config import CommonConfig from ...defaults import DEFAULT_NORMALIZATION_IMPLEMENTATION from ...enums import AttentionHeadType, PositionEmbeddingType -from ...modeling_utils import Alibi, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function +from ...modeling_utils import ( + Alibi, + ParameterizedEmbedding, + RoPE, + YaRNScaledRoPE, + get_normalization_function, +) from ...utils import convert_padding_free_lists_to_tensors, divide_if_divisible @@ -39,13 +48,19 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> None: self.attention_implementation = self.config._attn_implementation self._use_eager_attention = self.attention_implementation == "eager" self._use_sdpa = self.attention_implementation == "sdpa" - self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2" - self._use_padding_free_transformer = kwargs.get("use_padding_free_transformer", False) + self._use_flash_attention_2 = ( + self.attention_implementation == "flash_attention_2" + ) + self._use_padding_free_transformer = kwargs.get( + "use_padding_free_transformer", False + ) self._tied_word_embeddings = config.tie_word_embeddings if self._use_padding_free_transformer: - assert self._use_flash_attention_2, "padding free transformer only works with flash attention" + assert ( + self._use_flash_attention_2 + ), "padding free transformer only works with flash attention" def _init_weights(self, module: nn.Module) -> None: if hasattr(module, "reset_parameters"): @@ -74,28 +89,43 @@ def prepare_inputs_for_model( ) assert cu_seqlens is None, error_message.format(variable="cu_seqlens") assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format(variable="attention_mask") - - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( - convert_padding_free_lists_to_tensors( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - device=torch.cuda.current_device(), - ) + assert attention_mask is None, error_message.format( + variable="attention_mask" + ) + + ( + input_ids, + position_ids, + token_type_ids, + labels, + cu_seqlens, + max_seqlen, + ) = convert_padding_free_lists_to_tensors( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + device=torch.cuda.current_device(), ) else: assert ( cu_seqlens is not None ), "cu_seqlens needs to be specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" + assert ( + 
position_ids is not None + ), "max_seqlen needs to be specified when specifying cu_seqlens" + assert ( + max_seqlen is not None + ), "max_seqlen needs to be specified when specifying cu_seqlens" + assert ( + attention_mask is None + ), "attention_mask should not be passed when specifying cu_seqlens" if use_cache or past_key_values is not None: - raise NotImplementedError("KV caching is not supported with padding_free transformer") + raise NotImplementedError( + "KV caching is not supported with padding_free transformer" + ) assert not output_attentions @@ -128,9 +158,13 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})", ) - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, self.embed_dim, std=self.initializer_range + ) - self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.drop = ( + nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + ) self.h = nn.ModuleList( [ self.layer_class( @@ -150,7 +184,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: normalization_implementation=self.normalization_implementation, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -206,7 +242,9 @@ def forward( # attention_mask -> (batch_size, 1, query_length, key_length) # ========================================================================================== - past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + past_key_values = ( + DynamicCache() if use_cache and past_key_values is None else past_key_values + ) all_hidden_states = () if output_hidden_states else None for block in self.h: if output_hidden_states: @@ -234,7 +272,12 @@ def forward( ) def _get_position_ids( - self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device + self, + attention_mask: torch.Tensor, + past_length: int, + query_length: int, + key_length: int, + device: torch.device, ) -> torch.Tensor: if attention_mask is not None and len(attention_mask.shape) == 2: # create position_ids on the fly for batch generation @@ -243,7 +286,9 @@ def _get_position_ids( if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, key_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids @@ -277,7 +322,11 @@ def _get_alibi_bias( return alibi_bias def _get_rope_cos_sin( - self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device + self, + key_length: int, + position_ids: torch.Tensor, + dtype: torch.dtype, + device: torch.device, ) -> torch.Tensor: if self.position_embedding_type == PositionEmbeddingType.rope: cos, sin = self.rope(key_length, dtype=dtype, device=device) @@ -301,7 +350,9 @@ def _prepare_causal_attention_mask( if query_length > 1: # (query_length, key_length) - causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) + causal_mask = torch.empty( + 
(query_length, key_length), dtype=torch.bool, device=device + ) causal_mask[:, past_length:] = torch.tril( torch.ones(query_length, query_length, dtype=torch.bool, device=device) ) @@ -321,10 +372,18 @@ def _prepare_causal_attention_mask( else: if attention_mask is None: # (batch_size, query_length, key_length) - causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) + causal_mask = torch.ones( + batch_size, + query_length, + key_length, + dtype=torch.bool, + device=device, + ) else: # (batch_size, query_length, key_length) - causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) + causal_mask = attention_mask.unsqueeze(1).to( + dtype=torch.bool, device=device + ) # ========================================================================================== # attention_mask -> (batch_size, query_length, key_length) @@ -387,14 +446,20 @@ def _prepare_a_bunch_of_stuff( tuple[torch.Tensor], ]: output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) if use_cache is None: - use_cache = False if self._use_padding_free_transformer else self.config.use_cache + use_cache = ( + False if self._use_padding_free_transformer else self.config.use_cache + ) if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = input_ids.size() @@ -425,7 +490,10 @@ def _prepare_a_bunch_of_stuff( else: if self.position_embedding_type == PositionEmbeddingType.alibi: if position_ids is not None: - warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning) + warnings.warn( + "`position_ids` have no functionality with Alibi.", + FutureWarning, + ) if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) @@ -447,12 +515,16 @@ def _prepare_a_bunch_of_stuff( if self._use_padding_free_transformer: key_length = max_seqlen.item() else: - past_length = 0 if past_key_values is None else past_key_values.get_seq_length() + past_length = ( + 0 if past_key_values is None else past_key_values.get_seq_length() + ) query_length = input_shape[-1] key_length = past_length + query_length if position_ids is None: - position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device) + position_ids = self._get_position_ids( + attention_mask, past_length, query_length, key_length, device + ) # ========================================================================================== # padding_free: @@ -465,7 +537,9 @@ def _prepare_a_bunch_of_stuff( # position_ids -> (batch_size, query_length) # ========================================================================================== - hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids) + hidden_states = self._get_initial_hidden_state( + input_ids, inputs_embeds, position_ids, token_type_ids + ) # ========================================================================================== # padding_free: @@ -475,7 +549,12 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== alibi_bias = self._get_alibi_bias( - attention_mask, batch_size, 
query_length, key_length, device, hidden_states.dtype + attention_mask, + batch_size, + query_length, + key_length, + device, + hidden_states.dtype, ) # ========================================================================================== @@ -483,7 +562,10 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== rope_cos_sin = self._get_rope_cos_sin( - key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device + key_length, + position_ids, + dtype=hidden_states.dtype, + device=hidden_states.device, ) # ========================================================================================== @@ -494,7 +576,13 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== attention_mask = self._get_maybe_causal_mask( - attention_mask, alibi_bias, batch_size, query_length, key_length, hidden_states.dtype, device + attention_mask, + alibi_bias, + batch_size, + query_length, + key_length, + hidden_states.dtype, + device, ) return ( @@ -511,9 +599,13 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range) + self.wpe = ParameterizedEmbedding( + max_position_embeddings, self.embed_dim, std=self.initializer_range + ) elif self.position_embedding_type == PositionEmbeddingType.alibi: - assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention" + assert ( + not self._use_flash_attention_2 + ), "alibi is not implemented with FlashAttention" self.alibi = Alibi(self.num_heads) elif self.position_embedding_type == PositionEmbeddingType.rope: @@ -529,7 +621,9 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings=max_position_embeddings, base=self.config.rope_theta, scale=self.config.rope_scaling["factor"], - original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], ) elif self.position_embedding_type == PositionEmbeddingType.nope: pass @@ -538,8 +632,14 @@ def _setup_positional_encoding(self) -> None: def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + if ( + self.mask_value is None + or self.mask_value.dtype != dtype + or self.mask_value.device != device + ): + self.mask_value = torch.full( + [], torch.finfo(dtype).min, dtype=dtype, device=device + ) return self.mask_value def _get_maybe_causal_mask( @@ -568,7 +668,10 @@ def _get_maybe_causal_mask( # this is needed to prevent NaN since SDPA # see issue: https://github.com/pytorch/pytorch/issues/110213 attention_mask = attention_mask * ~torch.all( - attention_mask == self._get_mask_value(attention_mask.device, dtype), dim=-1, keepdim=True + attention_mask + == self._get_mask_value(attention_mask.device, dtype), + dim=-1, + keepdim=True, ) elif self._use_eager_attention: attention_mask = self._prepare_causal_attention_mask( diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/main.py b/src/instructlab/dolomite/hf_models/mixins/dense/main.py index b03b9ed..603ca53 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/main.py @@ -1,8 +1,13 @@ +# Third Party +from transformers import DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) import torch import torch.nn.functional as F -from transformers import DynamicCache -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +# Local from ...config import CommonConfig from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear from .base import PreTrainedModelMixin @@ -21,7 +26,10 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: if not self._tied_word_embeddings: self.lm_head = ParameterizedLinear( - config.n_embd, config.vocab_size, bias=False, std=config.initializer_range + config.n_embd, + config.vocab_size, + bias=False, + std=config.initializer_range, ) self.m_width = config.m_width @@ -112,18 +120,20 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, ) -> tuple | CausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) # ========================================================================================== @@ -155,7 +165,9 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) return CausalLMOutputWithPast( loss=loss, @@ -173,7 +185,10 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> 
torch.Tensor: ) def get_autoregressive_language_modeling_loss( - self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + self, + lm_logits: torch.Tensor, + labels: torch.Tensor | None, + cu_seqlens: torch.Tensor, ) -> torch.Tensor: if labels is None: return None @@ -193,6 +208,8 @@ def get_autoregressive_language_modeling_loss( if self.upcast_logits_for_loss: shift_logits = shift_logits.float() - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) return loss diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py index cbbb640..3adca7c 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py @@ -1,2 +1,3 @@ +# Local from .base import BaseModelMixin_TP, PreTrainedModelMixin_TP from .main import CausalLMModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py index 801bd72..ca41143 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py @@ -1,16 +1,25 @@ +# Third Party import torch.nn as nn +# Local from ....utils import ProcessGroupManager from ...config import CommonConfig from ...enums import AttentionHeadType, PositionEmbeddingType from ...modeling_utils import RoPE, YaRNScaledRoPE -from ...modeling_utils_TP import Alibi_TP, Dropout_TP, Embedding_TP, get_normalization_function_TP +from ...modeling_utils_TP import ( + Alibi_TP, + Dropout_TP, + Embedding_TP, + get_normalization_function_TP, +) from ..dense import BaseModelMixin, PreTrainedModelMixin class PreTrainedModelMixin_TP(PreTrainedModelMixin): def __init__(self, config: CommonConfig, *args, **kwargs): - self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.tensor_parallel_word_embeddings = kwargs.get( + "tensor_parallel_word_embeddings", False + ) self.sequence_parallel = kwargs.get("sequence_parallel", False) super().__init__(config, *args, **kwargs) @@ -67,7 +76,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: sequence_parallel=self.sequence_parallel, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -90,7 +101,9 @@ def _setup_positional_encoding(self) -> None: elif self.position_embedding_type == PositionEmbeddingType.rope: if self.config.rope_scaling is None: self.rope = RoPE( - self.head_dim, max_position_embeddings=max_position_embeddings, base=self.config.rope_theta + self.head_dim, + max_position_embeddings=max_position_embeddings, + base=self.config.rope_theta, ) else: self.rope = YaRNScaledRoPE( @@ -98,7 +111,9 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings=max_position_embeddings, base=self.config.rope_theta, scale=self.config.rope_scaling["factor"], - original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], ) else: raise NotImplementedError() diff --git 
a/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py index 4505921..cc8d019 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py @@ -1,14 +1,21 @@ +# Future from __future__ import annotations +# Standard from contextlib import nullcontext -import torch -import torch.nn.functional as F +# Third Party from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.tensor.parallel import loss_parallel from transformers import DynamicCache -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +import torch +import torch.nn.functional as F +# Local from ....utils import ProcessGroupManager, SafeTensorsWeightsManager from ...config import CommonConfig from ...enums import PositionEmbeddingType @@ -58,18 +65,20 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, ) -> tuple | CausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) transformer_outputs: BaseModelOutputWithPast = self.transformer( @@ -90,15 +99,21 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) if output_parallel_lm_logits: assert self.tensor_parallel_word_embeddings else: if self.tensor_parallel_word_embeddings: # all gather - lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + lm_logits = tensor_to_dtensor( + lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1) + ) + lm_logits = dtensor_to_tensor( + lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate() + ) return CausalLMOutputWithPast( loss=loss, @@ -123,7 +138,10 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: ) def get_autoregressive_language_modeling_loss( - self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + self, + lm_logits: torch.Tensor, + labels: torch.Tensor | None, + cu_seqlens: torch.Tensor, ) -> torch.Tensor: if labels is None: return None @@ -143,16 +161,24 @@ def get_autoregressive_language_modeling_loss( shift_logits = tensor_to_dtensor( shift_logits, device_mesh=self.tp_mesh, - current_placement=Shard(-1) if self.tensor_parallel_word_embeddings else Replicate(), + 
current_placement=Shard(-1) + if self.tensor_parallel_word_embeddings + else Replicate(), + ) + shift_labels = tensor_to_dtensor( + shift_labels, device_mesh=self.tp_mesh, current_placement=Replicate() ) - shift_labels = tensor_to_dtensor(shift_labels, device_mesh=self.tp_mesh, current_placement=Replicate()) if self.upcast_logits_for_loss: shift_logits = shift_logits.float() - loss_context = loss_parallel if self.tensor_parallel_word_embeddings else nullcontext + loss_context = ( + loss_parallel if self.tensor_parallel_word_embeddings else nullcontext + ) with loss_context(): - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) return loss @@ -164,23 +190,35 @@ def from_pretrained( tensor_parallel_word_embeddings: bool = False, **kwargs, ) -> CausalLMModelMixin_TP: - config: CommonConfig = cls.config_class.from_pretrained(pretrained_model_name_or_path) + config: CommonConfig = cls.config_class.from_pretrained( + pretrained_model_name_or_path + ) # use dummy tensors to avoid initializing model here with torch.device("meta"): # try sharding vocab matrices if really struggling for memory - model = cls._from_config(config, tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, **kwargs) + model = cls._from_config( + config, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + **kwargs, + ) model = model.to(dtype=torch_dtype) # copy to device without copying storage model = model.to_empty(device=torch.cuda.current_device()) - model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) + model.load_from_safetensors_weights_manager( + SafeTensorsWeightsManager(pretrained_model_name_or_path) + ) return model - def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: + def load_from_safetensors_weights_manager( + self, safetensors_weights_manager: SafeTensorsWeightsManager + ) -> None: with torch.device(torch.cuda.current_device()): - position_embedding_type = PositionEmbeddingType(self.config.position_embedding_type) + position_embedding_type = PositionEmbeddingType( + self.config.position_embedding_type + ) if position_embedding_type == PositionEmbeddingType.alibi: self.transformer.alibi.reset_parameters() diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py index 12b6465..c247564 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py @@ -1,2 +1,7 @@ -from .base import BaseMoEModelMixin, MoeModelOutputWithPastAndAuxLoss, PreTrainedMoEModelMixin +# Local +from .base import ( + BaseMoEModelMixin, + MoeModelOutputWithPastAndAuxLoss, + PreTrainedMoEModelMixin, +) from .main import CausalLMMoEModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/base.py b/src/instructlab/dolomite/hf_models/mixins/moe/base.py index 54ed982..6f7166c 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/base.py @@ -1,10 +1,13 @@ +# Standard from dataclasses import dataclass -import torch -import torch.nn as nn +# Third Party from transformers import DynamicCache from transformers.modeling_outputs import MoeModelOutputWithPast +import torch +import torch.nn as nn +# Local from ...config import CommonConfig from ...enums import 
AttentionHeadType, PositionEmbeddingType from ...modeling_utils import ParameterizedEmbedding, get_normalization_function @@ -39,9 +42,13 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.head_dim = self.embed_dim // self.num_heads - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, self.embed_dim, std=self.initializer_range + ) - self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.drop = ( + nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + ) self.h = nn.ModuleList( [ self.layer_class( @@ -62,7 +69,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: normalization_implementation=self.normalization_implementation, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -116,7 +125,9 @@ def forward( # attention_mask -> (batch_size, 1, query_length, key_length) # ========================================================================================== - past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + past_key_values = ( + DynamicCache() if use_cache and past_key_values is None else past_key_values + ) all_hidden_states = () if output_hidden_states else None all_router_logits = () if output_router_logits else None total_aux_loss = 0 @@ -188,7 +199,9 @@ def _prepare_a_bunch_of_stuff( tuple[torch.Tensor], ]: output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits ) return super()._prepare_a_bunch_of_stuff( diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/main.py b/src/instructlab/dolomite/hf_models/mixins/moe/main.py index 89e9632..1138711 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/main.py @@ -1,7 +1,9 @@ -import torch +# Third Party from transformers import DynamicCache from transformers.modeling_outputs import MoeCausalLMOutputWithPast +import torch +# Local from ...config import CommonConfig from ..dense import CausalLMModelMixin from .base import MoeModelOutputWithPastAndAuxLoss @@ -32,18 +34,20 @@ def forward( max_seqlen: torch.Tensor | None = None, output_router_logits: bool | None = None, ) -> tuple | MoeCausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) # 
========================================================================================== @@ -76,7 +80,9 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + lm_loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) aux_loss = transformer_outputs.aux_loss if lm_loss is None: diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py index e4e90ab..1250111 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py @@ -1,2 +1,3 @@ +# Local from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP from .main import CausalLMMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py index 55b09de..55749ff 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py @@ -1,5 +1,7 @@ +# Third Party import torch.nn as nn +# Local from ....utils import ProcessGroupManager from ...config import CommonConfig from ...enums import AttentionHeadType, PositionEmbeddingType @@ -10,7 +12,9 @@ class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP): def __init__(self, config: CommonConfig, *args, **kwargs): - self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.tensor_parallel_word_embeddings = kwargs.get( + "tensor_parallel_word_embeddings", False + ) self.sequence_parallel = kwargs.get("sequence_parallel", False) super().__init__(config, *args, **kwargs) @@ -68,7 +72,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: sequence_parallel=self.sequence_parallel, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py index 8f5de69..f6fdb75 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py @@ -1,8 +1,10 @@ -import torch +# Third Party from torch.distributed._tensor.placement_types import Replicate, Shard from transformers import DynamicCache from transformers.modeling_outputs import MoeCausalLMOutputWithPast +import torch +# Local from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor from ..dense_TP import CausalLMModelMixin_TP from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss @@ -27,18 +29,20 @@ def forward( max_seqlen: torch.Tensor | None = None, output_router_logits: bool | None = None, ) -> tuple | MoeCausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, 
cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer( @@ -60,9 +64,13 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + lm_loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) aux_loss = tensor_to_dtensor( - transformer_outputs.aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate() + transformer_outputs.aux_loss, + device_mesh=self.tp_mesh, + current_placement=Replicate(), ) if lm_loss is None: @@ -75,8 +83,12 @@ def forward( else: if self.tensor_parallel_word_embeddings: # all gather - lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + lm_logits = tensor_to_dtensor( + lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1) + ) + lm_logits = dtensor_to_tensor( + lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate() + ) return MoeCausalLMOutputWithPast( loss=loss, diff --git a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py index 0ddd148..bade858 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py @@ -1,10 +1,11 @@ +# Third Party from transformers import AutoConfig +# Local from .bigcode import export_to_huggingface_bigcode, import_from_huggingface_bigcode from .granite import export_to_huggingface_granite, import_from_huggingface_granite from .llama import export_to_huggingface_llama, import_from_huggingface_llama - _MODEL_IMPORT_FUNCTIONS = { "gpt_bigcode": import_from_huggingface_bigcode, "granite": import_from_huggingface_granite, @@ -17,7 +18,9 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) model_type = config.model_type if model_type not in _MODEL_IMPORT_FUNCTIONS: - raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") + raise NotImplementedError( + f"the current model_type ({model_type}) is not yet supported" + ) import_function = _MODEL_IMPORT_FUNCTIONS[model_type] import_function(pretrained_model_name_or_path, save_path) @@ -30,9 +33,13 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) } -def export_to_huggingface(pretrained_model_name_or_path: str, save_path: str, model_type: str) -> None: +def export_to_huggingface( + pretrained_model_name_or_path: str, save_path: str, model_type: str +) -> None: if model_type not in _MODEL_EXPORT_FUNCTIONS: - raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") + raise NotImplementedError( + f"the current model_type ({model_type}) is not yet supported" + ) export_function = _MODEL_EXPORT_FUNCTIONS[model_type] export_function(pretrained_model_name_or_path, save_path) diff --git a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py 
index 9ee9339..5906aa4 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py @@ -1,12 +1,23 @@ +# Standard import shutil -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, GPTBigCodeConfig, GPTBigCodeForCausalLM +# Third Party +from transformers import ( + AutoConfig, + AutoTokenizer, + GenerationConfig, + GPTBigCodeConfig, + GPTBigCodeForCausalLM, +) +# Local from ..enums import AttentionHeadType, PositionEmbeddingType from ..models import GPTDolomiteConfig -def import_from_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: +def import_from_huggingface_bigcode( + pretrained_model_name_or_path: str, save_path: str +) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) original_config: GPTBigCodeConfig = AutoConfig.from_pretrained(save_path) @@ -23,7 +34,9 @@ def import_from_huggingface_bigcode(pretrained_model_name_or_path: str, save_pat pass -def _import_config_from_huggingface(original_config: GPTBigCodeConfig) -> GPTDolomiteConfig: +def _import_config_from_huggingface( + original_config: GPTBigCodeConfig, +) -> GPTDolomiteConfig: assert original_config.activation_function in ["gelu_pytorch_tanh", "gelu"] config = GPTDolomiteConfig( @@ -52,7 +65,9 @@ def _import_config_from_huggingface(original_config: GPTBigCodeConfig) -> GPTDol return config -def export_to_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: +def export_to_huggingface_bigcode( + pretrained_model_name_or_path: str, save_path: str +) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) config: GPTDolomiteConfig = AutoConfig.from_pretrained(save_path) @@ -72,8 +87,14 @@ def export_to_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GPTBigCodeConfig: assert config.activation_function == "gelu_pytorch_tanh" assert config.normalization_function == "layernorm" - assert AttentionHeadType(config.attention_head_type) in [AttentionHeadType.mha, AttentionHeadType.mqa] - assert PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute + assert AttentionHeadType(config.attention_head_type) in [ + AttentionHeadType.mha, + AttentionHeadType.mqa, + ] + assert ( + PositionEmbeddingType(config.position_embedding_type) + == PositionEmbeddingType.learned_absolute + ) assert config.m_emb is None assert config.m_residual is None assert config.m_width is None diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granite.py b/src/instructlab/dolomite/hf_models/model_conversion/granite.py index c9af0d6..0d9fd63 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/granite.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/granite.py @@ -1,19 +1,28 @@ +# Third Party from transformers import AutoConfig, AutoTokenizer, GenerationConfig +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from ..models import GPTDolomiteConfig -from .llama import _export_state_dict_to_huggingface, _import_state_dict_from_huggingface - +from .llama import ( + _export_state_dict_to_huggingface, + _import_state_dict_from_huggingface, +) try: + # Third Party from transformers import GraniteConfig, GraniteForCausalLM except: GraniteConfig = None -def import_from_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, 
downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_granite( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -36,7 +45,9 @@ def import_from_huggingface_granite(pretrained_model_name_or_path: str, save_pat tokenizer.save_pretrained(save_path, legacy_format=False) -def _import_config_from_huggingface(original_config: GraniteConfig) -> GPTDolomiteConfig: +def _import_config_from_huggingface( + original_config: GraniteConfig, +) -> GPTDolomiteConfig: assert original_config.hidden_act == "silu" if original_config.num_attention_heads == original_config.num_key_value_heads: @@ -71,20 +82,32 @@ def _import_config_from_huggingface(original_config: GraniteConfig) -> GPTDolomi bos_token_id=original_config.bos_token_id, eos_token_id=original_config.eos_token_id, pad_token_id=original_config.pad_token_id, - m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, - m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, - m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + m_emb=None + if original_config.embedding_multiplier == 1 + else original_config.embedding_multiplier, + m_residual=None + if original_config.residual_multiplier == 1 + else original_config.residual_multiplier, + m_width=None + if original_config.logits_scaling == 1 + else original_config.logits_scaling, attention_multiplier=original_config.attention_multiplier, ) return config -def export_to_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: - config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_granite( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -119,7 +142,9 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GraniteConfig: num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py index 478abac..5be1796 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py @@ -1,6 +1,8 @@ -import torch +# Third Party from transformers import AutoConfig, AutoTokenizer, GenerationConfig +import torch +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from 
..modeling_utils import ( @@ -9,15 +11,19 @@ ) from ..models import MoEDolomiteConfig - try: + # Third Party from transformers import GraniteMoeConfig, GraniteMoeForCausalLM except: GraniteMoeConfig = None -def import_from_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_granitemoe( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -41,7 +47,9 @@ def import_from_huggingface_granitemoe(pretrained_model_name_or_path: str, save_ tokenizer.save_pretrained(save_path, legacy_format=False) -def _import_config_from_huggingface(original_config: GraniteMoeConfig) -> MoEDolomiteConfig: +def _import_config_from_huggingface( + original_config: GraniteMoeConfig, +) -> MoEDolomiteConfig: assert original_config.hidden_act == "silu" if original_config.num_attention_heads == original_config.num_key_value_heads: @@ -80,9 +88,15 @@ def _import_config_from_huggingface(original_config: GraniteMoeConfig) -> MoEDol bos_token_id=original_config.bos_token_id, eos_token_id=original_config.eos_token_id, pad_token_id=original_config.pad_token_id, - m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, - m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, - m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + m_emb=None + if original_config.embedding_multiplier == 1 + else original_config.embedding_multiplier, + m_residual=None + if original_config.residual_multiplier == 1 + else original_config.residual_multiplier, + m_width=None + if original_config.logits_scaling == 1 + else original_config.logits_scaling, attention_multiplier=original_config.attention_multiplier, ) @@ -99,24 +113,36 @@ def _import_state_dict_from_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), + "transformer.wte.weight": safetensors_weights_manager.get_tensor( + "model.embed_tokens.weight" + ), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor( + "model.norm.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + 
) ) - state_dict[f"transformer.h.{layer_idx}.moe.gate.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight" - ).T.contiguous() + state_dict[f"transformer.h.{layer_idx}.moe.gate.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight" + ).T.contiguous() + ) state_dict[f"transformer.h.{layer_idx}.moe.c_fc.weight"] = ( _split_and_reorder_for_glu( @@ -128,32 +154,50 @@ def _import_state_dict_from_huggingface( .contiguous() ) state_dict[f"transformer.h.{layer_idx}.moe.c_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight") + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight" + ) .transpose(0, 1) .contiguous() ) - state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = ( + interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ) ) return state_dict -def export_to_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: - config: MoEDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_granitemoe( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: MoEDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -190,7 +234,9 @@ def _export_config_to_huggingface(config: MoEDolomiteConfig) -> GraniteMoeConfig num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, @@ -227,45 +273,71 @@ def _export_state_dict_to_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "model.embed_tokens.weight": 
safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor( + "transformer.wte.weight" + ), + "model.norm.weight": safetensors_weights_manager.get_tensor( + "transformer.ln_f.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" + ) ) state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_2.weight" + ) ) state_dict[f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.gate.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.gate.weight" + ) ).T.contiguous() - state_dict[f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight"] = _split_and_reorder_for_glu( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_fc.weight").transpose(0, 1) - ).contiguous() - state_dict[f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_proj.weight").transpose(0, 1) + state_dict[f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight"] = ( + _split_and_reorder_for_glu( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.c_fc.weight" + ).transpose(0, 1) + ).contiguous() + ) + state_dict[ + f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight" + ] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.c_proj.weight" + ).transpose(0, 1) ).contiguous() - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_weight, key_weight, value_weight = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" + ) ) return state_dict diff --git a/src/instructlab/dolomite/hf_models/model_conversion/llama.py 
b/src/instructlab/dolomite/hf_models/model_conversion/llama.py index dee5dd4..c94df33 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/llama.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/llama.py @@ -1,5 +1,13 @@ -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM +# Third Party +from transformers import ( + AutoConfig, + AutoTokenizer, + GenerationConfig, + LlamaConfig, + LlamaForCausalLM, +) +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from ..modeling_utils import ( @@ -7,11 +15,18 @@ split_query_key_value_tensor_for_attention, ) from ..models import GPTDolomiteConfig -from ..models.gpt_dolomite import interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp +from ..models.gpt_dolomite import ( + interleave_up_gate_tensor_for_mlp, + split_up_gate_tensor_for_mlp, +) -def import_from_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_llama( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -83,54 +98,100 @@ def _import_state_dict_from_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), + "transformer.wte.weight": safetensors_weights_manager.get_tensor( + "model.embed_tokens.weight" + ), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor( + "model.norm.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"), + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = ( + interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ), + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ), + ) ) if f"model.layers.{layer_idx}.mlp.up_proj.bias" in 
safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.bias"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.bias"), + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] = ( + interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.up_proj.bias" + ), + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.gate_proj.bias" + ), + ) ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.weight" + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ) ) - if f"model.layers.{layer_idx}.mlp.down_proj.bias" in safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.bias" + if ( + f"model.layers.{layer_idx}.mlp.down_proj.bias" + in safetensors_weights_manager + ): + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.bias" + ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = ( + interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) - if f"model.layers.{layer_idx}.self_attn.q_proj.bias" in safetensors_weights_manager: + if ( + f"model.layers.{layer_idx}.self_attn.q_proj.bias" + in safetensors_weights_manager + ): state_dict[f"transformer.h.{layer_idx}.attn.c_attn.bias"] = ( interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.bias"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.bias"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.bias"), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.bias" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.bias" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.bias" + ), num_heads, num_key_value_heads, head_dim, @@ -138,22 +199,35 @@ def _import_state_dict_from_huggingface( ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] 
= ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ) ) - if f"model.layers.{layer_idx}.self_attn.o_proj.bias" in safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.bias" + if ( + f"model.layers.{layer_idx}.self_attn.o_proj.bias" + in safetensors_weights_manager + ): + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.bias" + ) ) return state_dict -def export_to_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: - config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_llama( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -192,7 +266,9 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> LlamaConfig: num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, @@ -221,71 +297,101 @@ def _export_state_dict_to_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor( + "transformer.wte.weight" + ), + "model.norm.weight": safetensors_weights_manager.get_tensor( + "transformer.ln_f.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" + ) ) state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_2.weight" + ) ) up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_fc.weight" + ) ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight if 
f"transformer.h.{layer_idx}.mlp.c_fc.bias" in safetensors_weights_manager: up_bias, gate_bias = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.bias") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_fc.bias" + ) ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.bias"] = up_bias state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.bias"] = gate_bias - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.weight" + ) ) if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weights_manager: - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.bias" + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.bias" + ) ) - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_weight, key_weight, value_weight = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight if f"transformer.h.{layer_idx}.attn.c_attn.bias" in safetensors_weights_manager: - query_bias, key_bias, value_bias = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.bias"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_bias, key_bias, value_bias = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.bias" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.bias"] = query_bias state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.bias"] = key_bias state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.bias"] = value_bias - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" + ) ) if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weights_manager: - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.bias" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.bias" + ) ) return state_dict diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py 
b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py index 92aea83..f29f9cd 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py @@ -1,3 +1,4 @@ +# Local from .activations import get_activation_function, is_glu from .attention import ( SDPA, diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py index 478c5dd..8cf8873 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py @@ -1,5 +1,7 @@ +# Third Party import torch.nn as nn +# Local from .base import get_base_activation from .glu import get_glu_activation, is_glu diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py index 3a8d155..f58cd32 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py @@ -1,6 +1,6 @@ -import torch.nn as nn +# Third Party from transformers.activations import ACT2CLS, ClassInstantier - +import torch.nn as nn _BASE_ACTIVATIONS = { "celu": nn.modules.CELU, diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py index 1419488..a59cd0e 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py @@ -1,9 +1,10 @@ +# Third Party import torch import torch.nn as nn +# Local from .base import get_base_activation - _GLU_BASE_MAPPING = { "ceglu": "celu", "eglu": "elu", diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py index c743985..46cf7b1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py @@ -1,7 +1,10 @@ +# Standard import inspect +# Third Party import torch +# Local from ...config import CommonConfig from ...enums import AttentionHeadType from .base import Attention @@ -18,7 +21,6 @@ split_query_key_value_tensor_for_mqa, ) - _ATTENTION_MODULES = { "eager": Attention, "sdpa": SDPA, @@ -69,7 +71,9 @@ def interleave_query_key_value_tensor_for_attention( ) -> torch.Tensor: if attention_head_type.value in _INTERLEAVE_FUNCTIONS: interleave_function = _INTERLEAVE_FUNCTIONS[attention_head_type.value] - interleave_function_parameters = inspect.signature(interleave_function).parameters.keys() + interleave_function_parameters = inspect.signature( + interleave_function + ).parameters.keys() parameters_to_pass = {} this_function_parameters = locals() diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py index 51903f1..7edf15e 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py @@ -1,10 +1,13 @@ +# Standard import math +# Third Party +from transformers import DynamicCache import torch import torch.nn as nn import torch.nn.functional as F -from transformers import DynamicCache +# Local from ...config import CommonConfig 
from ...enums import AttentionHeadType, InitMethod, PositionEmbeddingType from ...utils import divide_if_divisible @@ -14,7 +17,9 @@ class Attention(nn.Module): - def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = None) -> None: + def __init__( + self, config: CommonConfig, causal: bool, layer_idx: int | None = None + ) -> None: super().__init__() self.causal = causal @@ -36,7 +41,9 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N self.attention_head_type = AttentionHeadType(config.attention_head_type) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self.scale_attn_weights = config.scale_attn_weights self.attention_multiplier = config.attention_multiplier @@ -64,9 +71,13 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values" + assert ( + self.num_key_value_heads == 1 + ), f"{self.__class__.__name__} should have 1 head for keys and values" else: - raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") + raise ValueError( + f"unexpected attention_head_type ({self.attention_head_type})" + ) # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features @@ -83,15 +94,23 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N std = initializer_range / math.sqrt(2 * n_layer) if init_method == InitMethod.mup: std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(self.hidden_size, self.hidden_size, bias=self.add_bias, std=std) + self.c_proj = ParameterizedLinear( + self.hidden_size, self.hidden_size, bias=self.add_bias, std=std + ) self.attn_pdrop = config.attn_pdrop self.resid_pdrop = config.resid_pdrop - self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) - self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) + self.attn_dropout = ( + nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) + ) + self.resid_dropout = ( + nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) + ) - def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _prepare_qkv_for_forward( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) # ========================================================================================== @@ -111,7 +130,9 @@ def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> tuple[torch.T elif self.attention_head_type == AttentionHeadType.mqa: query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states) else: - raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") + raise ValueError( + f"unexpected attention_head_type ({self.attention_head_type})" + ) # ========================================================================================== # query -> (batch_size, 
num_heads, query_length, head_dim) @@ -138,10 +159,17 @@ def _prepare_qkv_for_forward_gqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) + hidden_states = hidden_states.view( + batch_size, query_length, self.num_key_value_heads, -1 + ) query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ( + (self.num_heads // self.num_key_value_heads) * self.head_dim, + self.head_dim, + self.head_dim, + ), + dim=-1, ) # this needs to be a reshape instead of view sadly @@ -158,7 +186,9 @@ def _prepare_qkv_for_forward_mqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) + query, key, value = hidden_states.split( + (self.hidden_size, self.head_dim, self.head_dim), dim=-1 + ) query = query.view(batch_size, query_length, self.num_heads, -1) @@ -233,16 +263,20 @@ def forward( if attention_mask is None: attn_weights = torch.empty( - (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype + (batch_size * self.num_heads, query_length, key_length), + device=query.device, + dtype=query.dtype, ) beta = 0 else: - attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape(-1, query_length, key_length) + attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape( + -1, query_length, key_length + ) beta = 1 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False)).view( - batch_size, self.num_heads, query_length, key_length - ) + attn_weights = torch.baddbmm( + attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False) + ).view(batch_size, self.num_heads, query_length, key_length) # ========================================================================================== # attn_weights -> (batch_size, num_heads, query_length, key_length) @@ -263,7 +297,9 @@ def forward( # ========================================================================================== attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) + attn_output = attn_output.reshape( + batch_size, -1, self.num_heads * self.head_dim + ) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py index 26bac53..9c44393 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py @@ -1,7 +1,9 @@ -import torch +# Third Party from transformers import DynamicCache from transformers.modeling_flash_attention_utils import _flash_attention_forward +import torch +# Local from ...enums import AttentionHeadType, PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py index 9b07a51..6338afc 100644 --- 
a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py @@ -1,13 +1,15 @@ -import torch +# Third Party from transformers import DynamicCache +import torch +# Local from ....utils import is_flash_attention_available from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention - if is_flash_attention_available(): + # Third Party from flash_attn.flash_attn_interface import flash_attn_varlen_func @@ -94,7 +96,12 @@ def _prepare_qkv_for_forward_gqa( hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ( + (self.num_heads // self.num_key_value_heads) * self.head_dim, + self.head_dim, + self.head_dim, + ), + dim=-1, ) # this needs to be a reshape instead of view sadly @@ -107,7 +114,9 @@ def _prepare_qkv_for_forward_mqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: total_q = hidden_states.shape[0] - query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) + query, key, value = hidden_states.split( + (self.hidden_size, self.head_dim, self.head_dim), dim=-1 + ) query = query.view(total_q, self.num_heads, -1) key = key.unsqueeze(1) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py index ad3290e..8188ee1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py @@ -1,7 +1,9 @@ +# Third Party +from transformers import DynamicCache import torch import torch.nn.functional as F -from transformers import DynamicCache +# Local from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention @@ -71,7 +73,9 @@ def forward( batch_size = attn_output.shape[0] attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) + attn_output = attn_output.reshape( + batch_size, -1, self.num_heads * self.head_dim + ) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py index ca60ca9..275fe56 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py @@ -1,3 +1,4 @@ +# Third Party import torch @@ -61,14 +62,21 @@ def interleave_query_key_value_tensor_for_gqa( def split_query_key_value_tensor_for_gqa( - query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int, head_dim: int + query_key_value_weight: torch.Tensor, + num_heads: int, + num_key_value_heads: int, + head_dim: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: query_heads_per_group = num_heads // num_key_value_heads original_shape = query_key_value_weight.shape - query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1) + query_key_value_weight = query_key_value_weight.view( + num_key_value_heads, (query_heads_per_group + 2), -1 + ) - 
query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1) + query_weight, key_weight, value_weight = query_key_value_weight.split( + (query_heads_per_group, 1, 1), 1 + ) query_weight = query_weight.reshape(-1, *original_shape[1:]) key_weight = key_weight.reshape(-1, *original_shape[1:]) @@ -92,7 +100,9 @@ def split_query_key_value_tensor_for_mqa( return query_key_value_weight.split((num_heads * head_dim, head_dim, head_dim)) -def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor: +def repeat_key_value( + x: torch.Tensor, num_heads: int, num_key_value_heads: int +) -> torch.Tensor: num_groups = num_heads // num_key_value_heads if num_groups == 1: diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py index 3cff32e..806a588 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py index 560e100..524b9c7 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py index b4bf746..edb0856 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py @@ -1,9 +1,10 @@ +# Third Party import torch.nn as nn +# Local from .layernorm import get_layernorm from .rmsnorm import get_rmsnorm - _NORMALIZATION_FUNCTIONS = { "layernorm": get_layernorm, "rmsnorm": get_rmsnorm, @@ -18,7 +19,11 @@ def get_normalization_function( ) -> nn.LayerNorm: if name in _NORMALIZATION_FUNCTIONS: return _NORMALIZATION_FUNCTIONS[name]( - normalized_shape, eps=eps, normalization_implementation=normalization_implementation + normalized_shape, + eps=eps, + normalization_implementation=normalization_implementation, ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py index 915c7ca..95ef207 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py @@ -1,9 +1,10 @@ +# Third Party import torch.nn as nn +# Local from .apex import ApexLayerNorm from .apex_persistent import ApexPersistentLayerNorm - _LAYERNORM_MODULES = { "torch": nn.LayerNorm, "apex": ApexLayerNorm, @@ -17,6 +18,10 @@ def get_layernorm( normalization_implementation: str = "torch", ) -> nn.LayerNorm: if normalization_implementation in _LAYERNORM_MODULES: - return _LAYERNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + return _LAYERNORM_MODULES[normalization_implementation]( + 
normalized_shape=normalized_shape, eps=eps + ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py index 763ad7f..5d60023 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py @@ -1,9 +1,11 @@ +# Third Party import torch import torch.nn as nn def is_apex_layernorm_available() -> bool: try: + # Third Party from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction return True @@ -12,14 +14,21 @@ def is_apex_layernorm_available() -> bool: if is_apex_layernorm_available(): + # Third Party from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction def apex_layernorm( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient: bool + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient: bool, ) -> torch.Tensor: normalized_shape = (input.shape[-1],) - return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps, memory_efficient) + return FusedLayerNormAffineFunction.apply( + input, weight, bias, normalized_shape, eps, memory_efficient + ) class ApexLayerNorm(nn.LayerNorm): diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py index e3ac497..40ec27c 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py @@ -1,9 +1,11 @@ +# Third Party import torch import torch.nn as nn def is_apex_persistent_layernorm_available() -> bool: try: + # Third Party from apex.contrib.layer_norm.layer_norm import FastLayerNormFN return True @@ -12,6 +14,7 @@ def is_apex_persistent_layernorm_available() -> bool: if is_apex_persistent_layernorm_available(): + # Third Party from apex.contrib.layer_norm.layer_norm import FastLayerNormFN @@ -44,7 +47,11 @@ def is_apex_persistent_layernorm_available() -> bool: def apex_persistent_layernorm( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient, ) -> torch.Tensor: return FastLayerNormFN.apply(input, weight, bias, eps, memory_efficient) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py index 42a64c3..a7c7dc1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py @@ -1,11 +1,17 @@ +# Third Party import torch.nn as nn +# Local from .apex import ApexRMSNorm from .base import RMSNorm -#from .torchtitan import TorchTitanRMSNorm + +# from .torchtitan import TorchTitanRMSNorm # Removing TorchTitanRMSNorm to avoid unecessary imports and checks -_RMSNORM_MODULES = {"torch": RMSNorm, "apex": ApexRMSNorm}#, 
"torchtitan": TorchTitanRMSNorm} +_RMSNORM_MODULES = { + "torch": RMSNorm, + "apex": ApexRMSNorm, +} # , "torchtitan": TorchTitanRMSNorm} def get_rmsnorm( @@ -14,6 +20,10 @@ def get_rmsnorm( normalization_implementation: str = "torch", ) -> nn.LayerNorm: if normalization_implementation in _RMSNORM_MODULES: - return _RMSNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + return _RMSNORM_MODULES[normalization_implementation]( + normalized_shape=normalized_shape, eps=eps + ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py index c91f4e7..2c2d646 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py @@ -1,10 +1,14 @@ +# Third Party import torch import torch.nn as nn def is_apex_rmsnorm_available() -> bool: try: - from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + # Third Party + from apex.normalization.fused_layer_norm import ( + FusedRMSNormAffineMixedDtypesFunction, + ) return True except ImportError: @@ -12,12 +16,19 @@ def is_apex_rmsnorm_available() -> bool: if is_apex_rmsnorm_available(): - from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + # Third Party + from apex.normalization.fused_layer_norm import ( + FusedRMSNormAffineMixedDtypesFunction, + ) -def apex_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float, memory_efficient: bool) -> torch.Tensor: +def apex_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float, memory_efficient: bool +) -> torch.Tensor: normalized_shape = (input.shape[-1],) - return FusedRMSNormAffineMixedDtypesFunction.apply(input, weight, normalized_shape, eps, memory_efficient) + return FusedRMSNormAffineMixedDtypesFunction.apply( + input, weight, normalized_shape, eps, memory_efficient + ) class ApexRMSNorm(nn.RMSNorm): diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py index 82dd4a2..8f45676 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py index c5fd754..38bd5af 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py @@ -10,16 +10,18 @@ """Code taken from torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py""" - +# Standard import math +# Third Party import torch import torch.nn as nn +# Local from .....utils import is_triton_available - if is_triton_available(): + # Third Party import triton import triton.language as tl @@ -113,7 +115,9 @@ def _rms_norm_bwd_kernel_sm( for row in range(row_start, row_end): # Load input, output 
gradient, and reciprocal standard deviation x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) - dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to( + tl.float32 + ) rstd = tl.load(Rstd + row) # Compute normalized input and gradients @@ -153,7 +157,9 @@ def forward(ctx, x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Ten raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (M,) - _rms_norm_fwd_kernel[grid](x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N) + _rms_norm_fwd_kernel[grid]( + x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N + ) ctx.eps = eps ctx.save_for_backward(x, weight, rstd) @@ -189,14 +195,29 @@ def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: grid = lambda meta: (sm_count,) _rms_norm_bwd_kernel_sm[grid]( - x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, _dw, eps, M, N, rows_per_sm, block_N + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, ) dw = _dw.sum(0).to(weight.dtype) dx = dx.view(x_shape_start) return dx, dw, None -def torchtitan_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: +def torchtitan_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float +) -> torch.Tensor: return _TorchTitanRMSNorm.apply(input, weight, eps) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py index e82f7cf..1e6fb4f 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py @@ -1,2 +1,3 @@ +# Local from .alibi import Alibi from .rope import RoPE, YaRNScaledRoPE, apply_rotary_pos_emb diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py index 3f49177..585bf98 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py @@ -1,5 +1,7 @@ +# Standard import math +# Third Party import torch import torch.nn as nn @@ -21,24 +23,40 @@ def forward( ) -> torch.Tensor: if attention_mask is None: arange_tensor = ( - torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1) + torch.arange(key_length, device=device) + .unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, -1, -1) ) else: - arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1) + arange_tensor = ( + (attention_mask.cumsum(dim=-1) - 1) + .masked_fill_(attention_mask == 0, 0) + .unsqueeze(1) + ) alibi = self.slopes.unsqueeze(1) * arange_tensor return alibi.to(dtype) def reset_parameters(self) -> None: closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != self.num_heads: - extra_base = torch.tensor(2 ** (-(2 ** 
-(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32) + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min( + closest_power_of_2, self.num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32 + ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) self.register_buffer("slopes", slopes, persistent=False) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py index 71c5916..5a3d0d3 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py @@ -1,7 +1,9 @@ """Logic is copied from transformers.models.llama.modeling_utils with slight modifications""" +# Standard import math +# Third Party import torch import torch.nn as nn @@ -22,7 +24,9 @@ def __init__( self.reset_parameters() - def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, seq_len: int, dtype: torch.dtype, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) @@ -32,10 +36,14 @@ def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> tup return cos, sin def reset_parameters(self) -> None: - self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=None, dtype=torch.float32) + self._set_cos_sin_cache( + seq_len=self.max_position_embeddings, device=None, dtype=torch.float32 + ) @torch.no_grad() - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None: + def _set_cos_sin_cache( + self, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> None: self.max_seq_len_cached = seq_len inv_freq = self._get_inv_freq(device) @@ -46,12 +54,20 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dt # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False) - self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False) + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False + ) def _get_inv_freq(self, device: torch.device) -> torch.Tensor: return 1.0 / ( - self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) / self.head_dim) + self.base + ** ( + torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) + / self.head_dim + ) ) @@ -86,17 +102,27 @@ def __init__( self.reset_parameters() def _get_inv_freq(self, device: torch.device) -> torch.Tensor: - pos_freqs = self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim) + pos_freqs = self.base ** ( + torch.arange(0, self.head_dim, 2).float() / self.head_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (self.scale * 
pos_freqs) low, high = _yarn_find_correction_range( - self.beta_fast, self.beta_slow, self.head_dim, self.base, self.original_max_position_embeddings + self.beta_fast, + self.beta_slow, + self.head_dim, + self.base, + self.original_max_position_embeddings, ) inv_freq_mask = ( - 1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float() - ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + (1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float()) + * self.extrapolation_factor + ) # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq @@ -118,15 +144,25 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: def _yarn_find_correction_dim( num_rotations: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 ) -> float: - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations def _yarn_find_correction_range( - low_rot: int, high_rot: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 + low_rot: int, + high_rot: int, + dim: int, + base: int = 10000, + max_position_embeddings: int = 2048, ) -> int: - low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case diff --git a/src/instructlab/dolomite/hf_models/models/__init__.py b/src/instructlab/dolomite/hf_models/models/__init__.py index 871910e..8eb2025 100644 --- a/src/instructlab/dolomite/hf_models/models/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/__init__.py @@ -2,4 +2,4 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .gpt_dolomite import GPTDolomiteForCausalLM, GPTDolomiteModel, GPTDolomiteConfig +from .gpt_dolomite import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py index 347102e..07f56fb 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py @@ -1,3 +1,4 @@ +# Local from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel from .config import GPTDolomiteConfig from .main import GPTDolomiteForCausalLM diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py index c9bee9d..5c9a4b6 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py @@ -1,3 +1,4 @@ +# Local from ...mixins import BaseModelMixin, PreTrainedModelMixin from .config import GPTDolomiteConfig from .layer import GPTDolomiteBlock diff --git 
a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py index 8b83592..8015ae2 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py @@ -1,3 +1,4 @@ +# Local from ...config import CommonConfig diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py index 5fc15a5..7f9c954 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py @@ -1,7 +1,9 @@ +# Third Party +from transformers import DynamicCache import torch import torch.nn as nn -from transformers import DynamicCache +# Local from ...enums import AttentionHeadType from ...modeling_utils import get_attention_module, get_normalization_function from .config import GPTDolomiteConfig @@ -36,7 +38,11 @@ def __init__( normalization_implementation=normalization_implementation, ) self.attn = get_attention_module( - config, True, attention_implementation, use_padding_free_transformer, layer_idx + config, + True, + attention_implementation, + use_padding_free_transformer, + layer_idx, ) self.ln_2 = get_normalization_function( config.normalization_function, diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py index cba1599..1e7527f 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py @@ -1,3 +1,4 @@ +# Local from ...mixins import CausalLMModelMixin from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py index b94e41a..e08cbcc 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py @@ -1,8 +1,11 @@ +# Standard import math +# Third Party import torch import torch.nn as nn +# Local from ...enums import InitMethod from ...modeling_utils import ParameterizedLinear, get_activation_function, is_glu from .config import GPTDolomiteConfig @@ -38,9 +41,13 @@ def __init__(self, config: GPTDolomiteConfig) -> None: std = initializer_range / math.sqrt(2 * n_layer) if init_method == InitMethod.mup: std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(intermediate_size, hidden_size, bias=add_bias, std=std) + self.c_proj = ParameterizedLinear( + intermediate_size, hidden_size, bias=add_bias, std=std + ) - self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) + self.dropout = ( + nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.c_fc(hidden_states) @@ -50,9 +57,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def interleave_up_gate_tensor_for_mlp(up_weight: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor: +def interleave_up_gate_tensor_for_mlp( + up_weight: torch.Tensor, gate_weight: torch.Tensor +) -> torch.Tensor: return torch.cat([up_weight, gate_weight]) -def split_up_gate_tensor_for_mlp(c_fc_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def split_up_gate_tensor_for_mlp( + c_fc_weight: torch.Tensor, 
+) -> tuple[torch.Tensor, torch.Tensor]: return c_fc_weight.chunk(2) diff --git a/src/instructlab/dolomite/hf_models/register_hf.py b/src/instructlab/dolomite/hf_models/register_hf.py index e92e456..426cd3b 100644 --- a/src/instructlab/dolomite/hf_models/register_hf.py +++ b/src/instructlab/dolomite/hf_models/register_hf.py @@ -1,11 +1,13 @@ -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM - -from .models import ( - GPTDolomiteConfig, - GPTDolomiteForCausalLM, - GPTDolomiteModel, +# Third Party +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, ) +# Local +from .models import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel # (AutoConfig, AutoModel, AutoModelForCausalLM) _CUSTOM_MODEL_REGISTRY = [ @@ -16,7 +18,11 @@ def register_model_classes() -> None: - for config_class, auto_model_class, auto_model_for_causal_lm_class in _CUSTOM_MODEL_REGISTRY: + for ( + config_class, + auto_model_class, + auto_model_for_causal_lm_class, + ) in _CUSTOM_MODEL_REGISTRY: model_type = config_class.model_type AutoConfig.register(model_type, config_class) @@ -27,5 +33,11 @@ def register_model_classes() -> None: _CUSTOM_MODEL_CLASSES.append(auto_model_for_causal_lm_class) -def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool: - return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES +def is_custom_model( + model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], + model_type: str, +) -> bool: + return ( + model_class.__name__ in _CUSTOM_MODEL_CLASSES + or model_type in _CUSTOM_MODEL_TYPES + ) diff --git a/src/instructlab/dolomite/hf_models/utils.py b/src/instructlab/dolomite/hf_models/utils.py index d6ae749..e66cea0 100644 --- a/src/instructlab/dolomite/hf_models/utils.py +++ b/src/instructlab/dolomite/hf_models/utils.py @@ -1,3 +1,4 @@ +# Third Party import torch @@ -25,13 +26,18 @@ def convert_padding_free_lists_to_tensors( labels: list[list[int]] | None = None, device: torch.device = None, ) -> tuple[torch.Tensor]: - # check input types are correct error_message = "{variable} should be of type List[List[{dtype}]]" _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) - _check_list_type(inputs_embeds, error_message.format(variable="inputs_embeds", dtype="float")) - _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) - _check_list_type(token_type_ids, error_message.format(variable="token_type_ids", dtype="int")) + _check_list_type( + inputs_embeds, error_message.format(variable="inputs_embeds", dtype="float") + ) + _check_list_type( + position_ids, error_message.format(variable="position_ids", dtype="int") + ) + _check_list_type( + token_type_ids, error_message.format(variable="token_type_ids", dtype="int") + ) _check_list_type(labels, error_message.format(variable="labels", dtype="int")) # prepare inputs for the model @@ -57,7 +63,9 @@ def convert_padding_free_lists_to_tensors( return input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen -def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: +def _check_list_type( + list_of_list: list[list[int | float]] | None, error_message: str +) -> None: if list_of_list is None: return diff --git a/src/instructlab/dolomite/utils/hf_hub.py b/src/instructlab/dolomite/utils/hf_hub.py index 82d3431..d85ebd9 100644 --- 
a/src/instructlab/dolomite/utils/hf_hub.py +++ b/src/instructlab/dolomite/utils/hf_hub.py @@ -1,11 +1,15 @@ +# Standard import os +# Third Party from transformers import AutoConfig, AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file from transformers.utils.hub import get_checkpoint_shard_files -def download_repo(repo_name_or_path: str) -> tuple[AutoConfig | None, AutoTokenizer | None, str]: +def download_repo( + repo_name_or_path: str, +) -> tuple[AutoConfig | None, AutoTokenizer | None, str]: config = _download_config(repo_name_or_path) tokenizer = _download_tokenizer(repo_name_or_path) model_path = None @@ -20,7 +24,9 @@ def download_repo(repo_name_or_path: str) -> tuple[AutoConfig | None, AutoTokeni except: # try downloading model weights if they are sharded try: - sharded_filename = cached_file(repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + sharded_filename = cached_file( + repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME + ) get_checkpoint_shard_files(repo_name_or_path, sharded_filename) model_path = os.path.dirname(sharded_filename) except: diff --git a/src/instructlab/dolomite/utils/safetensors.py b/src/instructlab/dolomite/utils/safetensors.py index a9ffd0b..65a0a75 100644 --- a/src/instructlab/dolomite/utils/safetensors.py +++ b/src/instructlab/dolomite/utils/safetensors.py @@ -1,11 +1,13 @@ +# Standard import json import os -import torch +# Third Party from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import save_file from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME +import torch class SafeTensorsWeightsManager: @@ -33,7 +35,10 @@ def get_slice(self, tensor_name: str): return f.get_slice(tensor_name) def get_tensor( - self, tensor_name: str, dtype: torch.dtype | None = None, device: torch.device | None = None + self, + tensor_name: str, + dtype: torch.dtype | None = None, + device: torch.device | None = None, ) -> torch.Tensor: filename = self.tensor_filenames[tensor_name] f = self.file_handles[filename] diff --git a/tox.ini b/tox.ini index 70c88a1..57ea095 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ [tox] # py3-unit runs unit tests with 'python3' -# py311-unit runs the same tests with 'python3.11' +# py312-unit runs the same tests with 'python3.12' envlist = ruff, lint, mypy, spellcheck minversion = 4.4
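
For reference, the converter entry points reflowed above keep their original two-argument signatures; the patch only reformats them. Below is a minimal usage sketch, not part of the patch itself, assuming the package is importable as `instructlab.dolomite` (the files live under `src/instructlab/dolomite/` in this repo) and using placeholder paths rather than values taken from the diff:

    # Illustrative sketch only; paths are hypothetical placeholders.
    from instructlab.dolomite.hf_models.model_conversion.llama import (
        export_to_huggingface_llama,
        import_from_huggingface_llama,
    )

    # Convert a Hugging Face Llama checkpoint into the GPTDolomite layout on disk.
    import_from_huggingface_llama("path/to/llama-hf-checkpoint", "path/to/dolomite-out")

    # Convert a GPTDolomite checkpoint back into the Hugging Face Llama layout.
    export_to_huggingface_llama("path/to/dolomite-out", "path/to/llama-hf-roundtrip")

The same two-argument calling pattern applies to the GraniteMoe converters (`import_from_huggingface_granitemoe` / `export_to_huggingface_granitemoe`) touched earlier in this patch.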