src/projected_compression/weight_importances.py (28 changes: 20 additions & 8 deletions)
@@ -9,13 +9,16 @@
 device = get_device()
 logger = logging.getLogger(__name__)


 def _calculate_activations_dimension_importances(
-    model: nn.Module, calibration_data, dmodel, dff, n_blocks, device="cuda"
+    model: nn.Module, calibration_data, dmodel, dff, n_blocks, mode, device="cuda"
 ):
     """
     Calculate importance of each neuron (dmodel and dff) using forward hooks.
     """
+    if mode not in ["minitron", "wanda"]:
+        raise ValueError(f"Unknown mode {mode} for activation dimension importance calculation.")
+    logger.info(f"Calculating {mode}-style weight importances.")

     dmodel_importances = torch.zeros(dmodel, device=device)
     dff_importances = torch.zeros(n_blocks, dff, device=device)

@@ -25,15 +28,18 @@ def _calculate_activations_dimension_importances(
     def hook_dmodel_pre_attn(layer, inp, out):
         nonlocal dmodel_importances
         # inp[0] has shape [batch, seq, dmodel]
-        dmodel_importances += torch.sum(torch.abs(out.detach()), dim=[0, 1])
+        dmodel_importances += torch.sum(torch.abs(out.detach()), dim=[0, 1]) if mode == "minitron" \
+            else torch.sqrt(torch.sum(torch.square(out.detach()), dim=[0, 1]))

     def hook_dmodel_pre_ff(layer, inp, out):
         nonlocal dmodel_importances
-        dmodel_importances += torch.sum(torch.abs(out.detach()), dim=[0, 1])
+        dmodel_importances += torch.sum(torch.abs(out.detach()), dim=[0, 1]) if mode == "minitron" \
+            else torch.sqrt(torch.sum(torch.square(out.detach()), dim=[0, 1]))

     def hook_ff_pre_act(layer, inp, out, block_idx=None):
         nonlocal dff_importances
-        dff_importances[block_idx] += torch.sum(torch.abs(out.detach()), dim=[0, 1])
+        dff_importances[block_idx] += torch.sum(torch.abs(out.detach()), dim=[0, 1]) if mode == "minitron" \
+            else torch.sqrt(torch.sum(torch.square(out.detach()), dim=[0, 1]))

     # --- Register hooks ---
     for block_idx, block in enumerate(model.encoder.blocks):
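
The hooks above switch between two accumulation rules: "minitron" sums
absolute activations per dimension, while "wanda" accumulates the
per-dimension L2 norm of the activations. A minimal, self-contained sketch
of the two metrics (toy shapes and names, not part of this PR):

    import torch

    # Toy activations with shape [batch, seq, dmodel]
    out = torch.randn(4, 16, 8)

    # "minitron" metric: sum of absolute activations per dimension -> [dmodel]
    minitron_scores = torch.sum(torch.abs(out), dim=[0, 1])

    # "wanda" metric: L2 norm of activations per dimension -> [dmodel]
    wanda_scores = torch.sqrt(torch.sum(torch.square(out), dim=[0, 1]))

Note that each hook call adds one norm per forward pass, so over several
calibration batches the "wanda" accumulator sums per-batch norms rather than
computing a single norm over all calibration data.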
@@ -64,6 +70,12 @@ def hook_ff_pre_act(layer, inp, out, block_idx=None):
     for h in handles:
         h.remove()

+    if mode == "wanda":  # not true Wanda, which induces sparsity at the weight-matrix level; only its importance metric is borrowed
+        dmodel_magnitudes, dff_magnitudes = _calculate_magnitude_dimension_importances(model)
+        dmodel_importances = dmodel_importances * dmodel_magnitudes
+        for i in range(n_blocks):
+            dff_importances[i] = dff_importances[i] * dff_magnitudes[i]
+
     return dmodel_importances, dff_importances
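
For context, Wanda scores each individual weight as |W_ij| * ||X_j||_2 and
prunes inside the weight matrix, whereas this code keeps or drops whole
dimensions, multiplying per-dimension activation norms by per-dimension weight
magnitudes. A hedged sketch of that distinction (hypothetical shapes; the mean
reduction over output rows is an assumption for illustration, the real values
come from the unchanged _calculate_magnitude_dimension_importances helper):

    import torch

    X = torch.randn(4, 16, 8)  # activations: [batch, seq, d_in]
    W = torch.randn(6, 8)      # linear-layer weights: [d_out, d_in]

    act_norms = torch.sqrt(torch.sum(torch.square(X), dim=[0, 1]))  # [d_in]

    # Original Wanda: one score per weight, used for unstructured sparsity
    per_weight_scores = torch.abs(W) * act_norms                    # [d_out, d_in]

    # Per-dimension variant: fold weight magnitudes into one score per dim
    # (mean over d_out assumed here for illustration)
    per_dim_scores = act_norms * torch.mean(torch.abs(W), dim=0)    # [d_in]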


@@ -144,7 +156,7 @@ def _calculate_magnitude_dimension_importances(model: nn.Module):
     return mean_dmodel_magnitudes, dff_magnitudes


-def minitron_importances(
+def activation_importances(
     model: nn.Module,
     dataloader,
     dmodel,
@@ -154,8 +166,8 @@
     total_batch_size,
     n_blocks,
     checkpoint_save_path,
+    mode="minitron",
 ):
-    logger.info(f"Calculating minitron style weight importances calculation.")
     model.to(device)
     world_size = int(os.environ["WORLD_SIZE"])
     batch_size = total_batch_size // world_size
@@ -173,7 +185,7 @@
         calibration_data[i] = batch[:, :seq_len]

     dmodel_importances, dff_importances = _calculate_activations_dimension_importances(
-        model, calibration_data, dmodel, dff, n_blocks
+        model, calibration_data, dmodel, dff, n_blocks, mode
     )
     logger.info(f"Calculated dimensions importances")
