
feat: Arbitrary-Rank Ablation (ARA)#211

Draft
p-e-w wants to merge 16 commits into master from ara

Conversation

@p-e-w
Owner

@p-e-w p-e-w commented Mar 4, 2026

Arbitrary-Rank Ablation (ARA) is a radically new abliteration method that I've been developing for the past two months or so. I believe that it can replace all currently implemented methods in Heretic, including MPOA, once the remaining issues are worked out. Its only serious competitor at this time is @kabachuha's implementation of multi-directional refusal suppression with Self-Organizing Maps (#196).

ARA doesn't use refusal directions at all, neither a single direction like traditional abliteration, nor multiple directions like SOMA. Instead, ARA works by capturing input/output tensors at each individual transformer module using PyTorch hooks, then uses direct, unconstrained matrix optimization to modify those modules, based on an objective function that captures the essence of what we want to (and don't want to) change.
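Mechanically, the capture step can be sketched roughly like this (a minimal sketch; the function name and structure are illustrative, not the actual `get_module_io` implementation):

```python
import torch
from torch import nn

def capture_module_io(model: nn.Module, module_names: list[str], inputs: torch.Tensor):
    """Record the input/output tensors of selected submodules during one forward pass."""
    captured: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
    handles = []
    for name, module in model.named_modules():
        if name in module_names:
            def hook(mod, inp, out, name=name):
                # inp is a tuple of positional arguments; keep the first tensor.
                captured[name] = (inp[0].detach(), out.detach())
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return captured
```

For a transformer, `module_names` would be the per-layer projection modules selected for abliteration.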

Intuitively, the objective encodes three competing optimization goals:

  1. The outputs of the module for inputs associated with "harmless" prompts should change as little as possible.
  2. The outputs of the module for inputs associated with "harmful" prompts should become as similar to those associated with "harmless" prompts as possible.
  3. The outputs of the module for inputs associated with "harmful" prompts should become as dissimilar to those previously associated with "harmful" prompts as possible. In combination with (2), this overcorrects away from the original residuals, which results in stronger steering that can overcome more complex refusal mechanisms.

Unlike other abliteration methods, this approach doesn't assume a particular rank for the refusal manifold, or that the centroid of the outputs must shift in a specific manner. This gives the optimizer more freedom to modify the matrix in the best possible way. Please see the code for implementation details.
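As a sketch, the three goals might be combined into a loss of roughly this shape, using the parameter names that appear in the trial logs later in this thread (the exact formula is my guess at the shape of the objective, not the actual implementation):

```python
import torch
from torch import Tensor

def ara_objective(
    matrix: Tensor,    # weight matrix being optimized, shape (d_out, d_in)
    good_in: Tensor,   # module inputs for "harmless" prompts, (n_good, d_in)
    bad_in: Tensor,    # module inputs for "harmful" prompts, (n_bad, d_in)
    good_out: Tensor,  # original outputs for harmless inputs, (n_good, d_out)
    bad_out: Tensor,   # original outputs for harmful inputs, (n_bad, d_out)
    preserve_good_behavior_weight: float = 1.0,
    steer_bad_behavior_weight: float = 0.05,
    overcorrect_relative_weight: float = 1.0,
    neighbor_count: int = 4,
) -> Tensor:
    new_good = good_in @ matrix.T
    new_bad = bad_in @ matrix.T
    # (1) Outputs for harmless inputs should change as little as possible.
    preserve = (new_good - good_out).pow(2).mean()
    # (2) Outputs for harmful inputs should move toward the harmless cluster:
    # for each harmful output, mean distance to its k nearest harmless outputs.
    attract = torch.cdist(new_bad, good_out).topk(
        neighbor_count, dim=1, largest=False
    ).values.mean()
    # (3) Outputs for harmful inputs should move away from their original
    # values; relative weight > 1 overcorrects past the original residuals.
    repel = (new_bad - bad_out).norm(dim=1).mean()
    return (preserve_good_behavior_weight * preserve
            + steer_bad_behavior_weight * (attract - overcorrect_relative_weight * repel))
```

With the unmodified matrix, terms (1) and (3) vanish and only the attraction term remains, which is why the starting point is already close to the optimum.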

The objective is affine-convex and the initial value (the original matrix) is already very close to the optimum, so L-BFGS makes short work of it, typically converging in 2-3 iterations. Because the matrices are optimized one-by-one, the total memory requirements are barely higher than for regular abliteration. The abliteration process takes longer, but the time per trial is still dominated by counting refusals. Combined with the fact that ARA has fewer optimizable parameters than our current approach (meaning that fewer trials are needed for good results), this might actually make ARA faster than regular abliteration.
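The per-matrix optimization loop can be sketched as follows (assuming a scalar `objective(matrix)` callable; names are illustrative, not the actual `ara_abliterate` method):

```python
import torch
from torch.optim import LBFGS

def optimize_matrix(matrix: torch.Tensor, objective, max_iter: int = 20) -> torch.Tensor:
    """Optimize a single weight matrix with L-BFGS against a scalar objective.

    The initial value (the original matrix) is typically close to the
    optimum, so few iterations are needed.
    """
    matrix = matrix.detach().clone().requires_grad_(True)
    optimizer = LBFGS([matrix], max_iter=max_iter, line_search_fn="strong_wolfe")

    def closure():
        optimizer.zero_grad()
        loss = objective(matrix)
        loss.backward()
        return loss

    optimizer.step(closure)
    return matrix.detach()
```

Because each matrix is optimized independently, only one matrix needs gradients at a time, which keeps the memory overhead small.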

Results

For demonstration purposes, I have processed openai/gpt-oss-20b with the exact code currently in this pull request. The result is p-e-w/gpt-oss-20b-heretic-ara-v3:

[Image: ara-results]

This is dramatically better than any existing abliteration of gpt-oss-20b (see this table), with the possible exception of the brand new kabachuha/gpt-oss-20b-SOMbliterated, which has the same refusal count but higher KL divergence.

TODO

ARA isn't quite ready for mainstream use yet, but it's getting close. The remaining issues are:

  • Figure out how to obtain good values for the weight constants in steer_bad_behavior.
  • Decide whether to have separate abliteration parameter sets for each component, like in our current implementation.
  • Fix multi-GPU issues.
  • Write a proper article explaining the mathematics and motivation behind ARA.

Feedback welcome!

@spikymoth
@kabachuha
@red40maxxer

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces Arbitrary-Rank Ablation (ARA), a radically new method for model steering that moves beyond traditional directional ablation. ARA leverages direct matrix optimization within individual transformer modules, guided by an objective function designed to minimize changes to harmless outputs while aggressively modifying harmful ones. This approach offers greater flexibility by not assuming a fixed refusal manifold rank, potentially leading to more robust and efficient abliteration results.

Highlights

  • Arbitrary-Rank Ablation (ARA) Introduction: Introduced a novel ablation method, Arbitrary-Rank Ablation (ARA), which uses direct, unconstrained matrix optimization on transformer modules instead of traditional directional ablation. This method aims to provide more flexible and effective model steering.
  • New Optimization Objective: Implemented a new objective function for ARA that balances three goals: preserving outputs for 'harmless' prompts, making 'harmful' prompt outputs similar to 'harmless' ones, and pushing 'harmful' prompt outputs away from their original state for stronger steering.
  • Conditional Ablation Logic: Integrated conditional logic throughout the codebase to switch between ARA and traditional ablation methods, affecting gradient handling, parameter optimization, module I/O capture, and model saving/loading.
  • Module I/O Capture with PyTorch Hooks: Developed functionality to capture input/output tensors at each transformer module using PyTorch hooks, which is crucial for the ARA optimization process.


Changelog
  • src/heretic/config.py
    • Added target_components setting to specify which model components to abliterate.
    • Added use_ara setting to toggle between ARA and traditional ablation.
  • src/heretic/main.py
    • Modified gradient enabling to be conditional on the use_ara setting.
    • Updated the main run function to conditionally calculate refusal directions (traditional) or module I/O (ARA).
    • Adjusted the objective function in Optuna to suggest different parameters based on whether use_ara is enabled.
    • Modified model reset, abliteration call, and model saving/uploading logic to be conditional on use_ara.
  • src/heretic/model.py
    • Added imports for LBFGS, RemovableHandle, Callable, and TypeAlias.
    • Defined ModuleIO type alias for storing module input/output tensors.
    • Made _apply_lora and reset_model calls conditional on use_ara.
    • Updated get_layer_modules to filter components based on settings.target_components.
    • Implemented the ara_abliterate method, which performs matrix optimization using L-BFGS based on captured module I/O.
    • Added get_module_io and get_module_io_batched methods to capture input/output tensors from model modules using PyTorch hooks.
  • src/heretic/utils.py
    • Imported Tensor from torch.
    • Added mean_distances_to_knn utility function for calculating mean Euclidean distances to k-nearest neighbors.
    • Modified get_trial_parameters to return different sets of parameters depending on whether ARA is enabled.
Activity
  • The author, p-e-w, developed and implemented the Arbitrary-Rank Ablation (ARA) method over approximately two months.
  • Demonstrated results with openai/gpt-oss-20b processed using ARA, showing significantly improved refusal counts compared to existing methods.
  • Requested feedback from specific reviewers (@spikymoth, @kabachuha, @red40maxxer).
  • Outlined remaining TODOs for ARA, including optimizing weight constants, deciding on separate abliteration parameter sets for components, and writing a mathematical article.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new abliteration method called Arbitrary-Rank Ablation (ARA). The changes are extensive, touching configuration, main application logic, and the model implementation to support this new method. The implementation uses PyTorch hooks to capture module I/O and an L-BFGS optimizer to modify module weights directly. The changes are mostly gated behind a new use_ara setting.

My feedback focuses on ensuring consistency with the repository's style guide and improving maintainability. Specifically, I've pointed out missing configuration updates, inconsistent trial parameter handling, a missing type annotation, and some minor style guide violations.

@kabachuha

kabachuha commented Mar 4, 2026

Congratulations on the conception!

Expected all tensors to be on the same device, but got mat2 is on cuda:0, different from other tensors on
cuda:1 (when checking argument in method wrapper_CUDA_mm)

Needs fix for multiGPU

@p-e-w
Owner Author

p-e-w commented Mar 4, 2026

Needs fix for multiGPU

Thanks for pointing this out. I don't have a multi-GPU setup myself, but I'll rent one to figure out where the problem is.

@kabachuha

Pareto frontier for Qwen3-4B-Instruct-2507.

Qwen series are reportedly notoriously hard to decensor, so I decided to test it.

? Which trial do you want to use? (Use arrow keys)
   [Trial 156] Refusals:  3/100, KL divergence: 2.9901
   [Trial  25] Refusals:  4/100, KL divergence: 1.1411
   [Trial 148] Refusals:  5/100, KL divergence: 0.2406
   [Trial 147] Refusals:  6/100, KL divergence: 0.2359
   [Trial 168] Refusals:  7/100, KL divergence: 0.1521
   [Trial 151] Refusals:  8/100, KL divergence: 0.1384
 » [Trial 154] Refusals: 10/100, KL divergence: 0.1345
   [Trial 170] Refusals: 17/100, KL divergence: 0.0770
   [Trial 174] Refusals: 31/100, KL divergence: 0.0749
   [Trial  73] Refusals: 45/100, KL divergence: 0.0390
   [Trial 142] Refusals: 56/100, KL divergence: 0.0237
   [Trial  86] Refusals: 87/100, KL divergence: 0.0188
   Run additional trials

Comparison with the other methods:

| Model | Refusals for "harmful" prompts | KL divergence from original model for "harmless" prompts |
|---|---|---|
| Qwen/Qwen3-4B-Instruct-2507 (original) | 100/100 | 0 (by definition) |
| kabachuha/Qwen3-4B-Instruct-2507-SOMbliterated | 3/100 | 0.08 |
| heretic-org/Qwen3-4B-Instruct-2507-heretic | 5/100 | 0.07 |
| Goekdeniz-Guelmez/Qwen3-4B-Instruct-2507-gabliterated | 4/100 | 0.25 |
| p-e-w/gpt-oss-20b-heretic-ara (this) | 8/100 | 0.14 |

I'd say the results are somewhere in between.

@kabachuha

kabachuha commented Mar 4, 2026

@p-e-w Can you submit the gpt-oss model to the UGI leaderboard? https://huggingface.co/spaces/DontPlanToEnd/UGI-Leaderboard/discussions

@spikymoth
Contributor

Interesting idea, and surprisingly straightforward! A couple of initial thoughts/questions:

  1. mean_distances_to_knn returns the distance to the k nearest neighbors. That makes sense to me for maximizing dissimilarity, but wouldn't you want something more like the opposite (the k farthest neighbors) for maximizing similarity?
  2. I think the magnitude-preserving part of MPOA could be added onto this pretty easily by modifying the closure to call the objective (loss) function on a norm-constrained matrix. Something like this:
```python
if self.settings.row_normalization == RowNormalization.FULL:
    # Get row norms of the original matrix.
    target_norms = torch.norm(matrix, dim=1, keepdim=True)

    def closure() -> Tensor:
        optimizer.zero_grad()
        # Compute the loss relative to the norm-constrained matrix.
        constrained_matrix = F.normalize(matrix, p=2, dim=1) * target_norms
        loss = objective(constrained_matrix)
        # Backpropagation yields the gradient with respect to the constrained matrix.
        loss.backward()
        return loss
else:
    def closure() -> Tensor:
        optimizer.zero_grad()
        loss = objective(matrix)
        loss.backward()
        return loss
```

@p-e-w
Owner Author

p-e-w commented Mar 4, 2026

@kabachuha The weights in steer_bad_behavior need to be fine-tuned for each model (currently by hand). If you run it on other models you will get poor results unless you tune them. I've seen much, much better results with Qwen3-4B-Instruct-2507 during testing.

In the future, this will happen automatically via Optuna, though it's not as straightforward as it may seem because the possible ranges go over multiple orders of magnitude.

@p-e-w
Owner Author

p-e-w commented Mar 4, 2026

@spikymoth

mean_distances_to_knn returns the distance to the k nearest neighbors. That makes sense to me for maximizing dissimilarity, but wouldn't you want something more like the opposite (the k farthest neighbors) for maximizing similarity?

I don't quite understand what you mean here. Could you explain more?

I think the magnitude-preserving part of MPOA could be added onto this pretty easily by modifying the closure to call the objective (loss) function on a norm-constrained matrix.

That's true, but I'm not convinced that preserving the magnitude is correct in general. If harmful and harmless prompts result in residuals of different magnitudes, then abliteration should change the magnitudes I think.

@spikymoth
Contributor

mean_distances_to_knn returns the distance to the k nearest neighbors. That makes sense to me for maximizing dissimilarity, but wouldn't you want something more like the opposite (the k farthest neighbors) for maximizing similarity?

I don't quite understand what you mean here. Could you explain more?

Well, do I understand correctly that it's 1. getting the distance between each pair of vectors, 2. selecting the k smallest distances and 3. returning the mean distance?

If so, it seems to target the harmful outputs that are already most similar to the harmless outputs and push them closer, while ignoring more dissimilar outputs. I'm wondering if this creates representative differences.

Perhaps the optimal way to capture representative differences would be to contrast the top-k SOM neurons (as the centers of gravity of the clusters) for a set of outputs.

That's true, but I'm not convinced that preserving the magnitude is correct in general. If harmful and harmless prompts result in residuals of different magnitudes, then abliteration should change the magnitudes I think.

IIRC the argument is that the row norms of the weight matrix overall should stay unchanged in order to preserve between-layer interpretability, i.e. each dimension in the output is expected to have a particular activation strength and if you change it, subsequent layers may get confused about what stronger/weaker activations mean.

@p-e-w
Owner Author

p-e-w commented Mar 4, 2026

If so, it seems to target the harmful outputs that are already most similar to the harmless outputs and push them closer, while ignoring more dissimilar outputs.

No, it targets all outputs.

It computes the mean distance to the k nearest harmless neighbors for each harmful output and then computes the mean of those means. So every harmful output is attracted towards its nearest harmless neighbors.

This is actually precisely where the strength of this method comes from, because directional ablation based on a difference of means optimizes towards a configuration where the mean of the modified harmful outputs resembles the mean of the harmless outputs. This is an unnecessary constraint that hinders finding an optimal configuration. With ARA, every single harmful output is simply attracted towards somewhere in the harmless cluster. There is no requirement that the means of the outputs align.
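A sketch of that operation (signature assumed; this mirrors the description above, not the actual code in src/heretic/utils.py):

```python
import torch
from torch import Tensor

def mean_distances_to_knn(points: Tensor, references: Tensor, k: int) -> Tensor:
    """For each row of `points`, the mean Euclidean distance to its k nearest
    rows of `references` (sketch of the behavior described in this thread)."""
    dists = torch.cdist(points, references)           # (n_points, n_refs)
    knn = dists.topk(k, dim=1, largest=False).values  # k smallest per row
    return knn.mean(dim=1)                            # (n_points,)
```

Taking `mean_distances_to_knn(bad_outputs, good_outputs, k).mean()` then attracts every harmful output toward its own nearest harmless neighbors, without constraining the cluster means to align.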

@spikymoth
Contributor

Aahh OK, I was misunderstanding the operation. So for every harmful output it computes the distance from all harmless outputs, then takes the mean of the k smallest distances (nearest neighbors) to push every output toward those neighbors. That makes sense, and should naturally give more weight to directions that show up more frequently.

@red40maxxer
Contributor

red40maxxer commented Mar 5, 2026

This is an awesome idea and I'm really looking forward to this being included in main. I'm running a bunch of tests right now to see how this performs vs. standard MPOA and kabachuha's SOM technique.

Is there any possible way to get quantization to work with this? I understand that ARA does gradient-based optimization on the weight matrix and bnb would create a different shape which breaks this, but being able to use quantization with this new technique would still be very valuable IMO. Maybe we could dequantize before ARA? This might not reduce the total RAM required for abliteration but it might at least speed up inference, which can be a bottleneck.

@p-e-w
Owner Author

p-e-w commented Mar 5, 2026

Maybe we could dequantize before ARA?

Yes, this should work. Matrices are processed one by one, so the memory impact of dequantizing an individual matrix to full precision should be relatively small.

@p-e-w
Owner Author

p-e-w commented Mar 5, 2026

I'm running a bunch of tests right now to see how this performs vs standard MPOA and kabachuha's SOM technique.

Make sure you use the latest commit (0bb9521). If you are seeing suboptimal results, I recommend trying some combination of these:

  1. Remove mlp.down_proj from target_components.
  2. Play with the value range for steer_bad_behavior_weight.

@kabachuha

Can you make some visualizations (ex. PCA) of the model's hidden states as the ARA method achieves convergence?

@erm14254
Contributor

erm14254 commented Mar 5, 2026

I encountered a bug that I thought I should report:

L-BFGS IndexError on Windows with high steer_bad_behavior_weight

Environment:

Windows 11 Pro
Python 3.12
PyTorch 2.10.0+cu130
PowerShell (non-admin)

Issue: The first trial failed with IndexError: list index out of range in torch/optim/lbfgs.py line 205, during _strong_wolfe. The failing parameters included steer_bad_behavior_weight = 0.3967 (much higher than your working trials at ~0.02-0.09).

I also saw this symlink error (possibly related): [WinError 1314] A required privilege is not held by the client: '...triton_kernels\__init__.py'

Possible causes:

A Windows/PowerShell-specific issue
Numerical instability with high steer_bad_behavior_weight values
Admin privileges being needed for the Triton kernels

After switching from PowerShell (non-admin) to a Command Prompt run as Administrator, the error messages did not appear and the trial completed successfully.

@p-e-w
Owner Author

p-e-w commented Mar 6, 2026

Can you make some visualizations (ex. PCA) of the model's hidden states as the ARA method achieves convergence?

Yes, I will do a full writeup explaining the motivation behind ARA, which will include such data.

@GhostWithAHat

Did some tests with Qwen 3.5 4B.

Main branch: the best trials still refuse more than 50 of 100 bad prompts.
After merging ara into main, without touching the values of the weight constants in steer_bad_behavior: good trials refuse 4 of 100 harmful prompts with a KLD of 0.1396.

Can't wait for somebody to make an ARA version of Qwen 3.5 27B.

@p-e-w
Owner Author

p-e-w commented Mar 7, 2026

@kabachuha

I am unable to reproduce the multi-GPU issue. I have tried processing Gemma 3 27B (which is 55 GB in BF16) on a 2x 5090 system, forcing tensor sharding. However, I am not getting a device mismatch error like you did.

Could you give some more information about the system where this error occurred?

@p-e-w
Owner Author

p-e-w commented Mar 10, 2026

@joninco

I have updated the parameter ranges based on your data and my own observations. Could you please run your test again with the new ranges?

@kabachuha

Even if the direction is determined correctly, there evidently emerges a new refusal pathway in newer models, see #221. Can be a future direction of research after ARA

@erm14254
Contributor

erm14254 commented Mar 10, 2026

@joninco

I have updated the parameter ranges based on your data and my own observations. Could you please run your test again with the new ranges?

Expanded ARA parameter ranges (c76416f) cause poor optimization results.

After commit c76416f expanded the parameter ranges, I'm seeing significantly worse results with ARA:

After 6 hours and 315 trials, roughly 95% of the results are unusable due to extremely high KL divergences (13.47 and above), rendering them worthless, and the few usable trials have a worse refusals-to-KL-divergence ratio than before c76416f. With the old ranges, most trials stay in reasonable KL territory, allowing the optimizer to efficiently find good solutions; with c76416f, the optimizer instead spends hundreds of trials learning to avoid these regions rather than finding optimal solutions.

The optimizer also converges to narrow layer ranges (9 layers) instead of the full coverage (32+ layers) seen in pre-c76416f runs.

The ~20x larger search space may need significantly more trials to converge properly (1000+ instead of ~300), so instead of getting good results after 3-6 hours, you might now need 18+ hours of trials to maybe get a good result.

Update: After 360+ trials, the optimizer is now exploring wider layer ranges (57-58 layers vs 9 layers earlier at trial 215). This confirms the expanded space needs significantly more trials to converge properly, roughly 2x the trials to achieve similar coverage.

Update 2: After 455 trials and many hours later, I can say that almost all results are unusable and running more trials did not improve at all on the "best" disappointing result reached all the way back at trial 215.

Update 3: Expanded ranges seem to be necessary to attain low refusals with ARA for certain models, despite the shortcomings (hundreds of trials spent on unusable results due to extremely high KL divergences).

Conclusion: While ARA is still being worked on, and its efficacy varies greatly from one model or finetune to another, consider adding a --expanded-ranges flag or an [ara] expanded_ranges = true/false config option, defaulting to the old ranges for faster convergence on most models.

@p-e-w
Owner Author

p-e-w commented Mar 11, 2026

@erm14254

majority of results where unusable due to crazily high KL divergences like 13.4725

hundreds of trials wasted on unusable results due to crazily high KL divergences

What you are observing is expected and not in itself indicative of a problem. Trials are not results, and exploration isn't "wasting" trials.

Here's what's going on: The region where overcorrect_relative_weight > 1.0 is necessarily unstable. That's because in that region, the objective pushes away from the bad residuals more strongly than it attracts to the good residuals, with the regularization term for the good residuals as the only counteracting force, so iteration behavior becomes essentially chaotic.

However, it is also precisely in that region where the best results can often be obtained, and the optimization process almost always finds those results in my tests. I am still exploring whether the trajectory can be tamed more efficiently, but so far it isn't clear that a change is necessary.

It doesn't matter if the optimizer spends 75% of the trials exploring parameters that yield a KLD of 15+. All that matters is the Pareto front at the end. To check whether there is an improvement or not, post the Pareto fronts for a fixed number of trials (200 should be enough, it certainly is in my tests), for the same model, before and after the commit. No other data demonstrates anything.
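The instability can be seen in a one-dimensional toy (illustrative only, not the actual objective): with loss `w_pull * (x - g)**2 - w_push * (x - b)**2`, the quadratic coefficient is `w_pull - w_push`, so for `w_push > w_pull` the loss is unbounded below and the iterates can run away, with only the regularization term holding them back.

```python
def toy_loss(x: float, good: float = 0.0, bad: float = 1.0,
             w_pull: float = 1.0, w_push: float = 1.2) -> float:
    """1-D analogue of the steering objective: attract toward `good`,
    repel from `bad`. When w_push > w_pull, the leading coefficient is
    negative, so moving x far away decreases the loss without bound."""
    return w_pull * (x - good) ** 2 - w_push * (x - bad) ** 2

# Unstable regime (w_push > w_pull): the loss keeps decreasing as x grows.
assert toy_loss(1000.0) < toy_loss(100.0) < toy_loss(6.0)
# Stable regime (w_push < w_pull): the loss grows as x moves away.
assert toy_loss(1000.0, w_push=0.8) > toy_loss(0.0, w_push=0.8)
```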

@erm14254
Contributor

erm14254 commented Mar 11, 2026

What you are observing is expected and not in itself indicative of a problem. Trials are not results, and exploration isn't "wasting" trials.

What I was trying to say is that for certain models, the expanded ranges are just "too much": they are unnecessary, don't make things better, and may actually make things worse or more difficult.

Testing results:

Better on pre-c76416f: Gemma 3 model family, GPT-OSS family

Better on post-c76416f: Qwen3.5 family (this varies per model and finetune, though), GLM-4.5 Flash (the new expanded ranges are most likely necessary to reach low refusals).

I haven't been able to test more models yet.

This is why I was thinking of:

  1. Implementing some kind of setting where the end user could manually adjust these parameters depending on the model being abliterated.
  2. Maybe even better: shipping saved "abliteration profiles" and letting the end user select the best profile for the model's architecture and/or family.
  3. Maybe Heretic could use some sort of detection mechanism, where the model is analyzed on startup and the most optimal parameters are chosen as a result, a little like the "determining optimal batch size" step that runs at the beginning of Heretic.

Option 2 looks nice and would probably be the easiest to implement, but if option 3 is possible, it might be the best and most practical for end users.

@p-e-w
Owner Author

p-e-w commented Mar 11, 2026

Better on pre-c76416f: Gemma 3 model family, GPT-OSS family

That contradicts my own tests. GPT-OSS especially gets much better results after c76416f.

Are you sure you're comparing the Pareto fronts at the end, rather than the trials as they scroll by? Because only the results determine which is better.

We're definitely not going to introduce abliteration profiles or manual parameters. The whole point of Heretic is that it's supposed to work with all models. If there is a model where it doesn't work properly, please let me know, but again: Only the results (Pareto front) matter.

Do you have a concrete example of a model where the Pareto front was better pre-c76416f?

@erm14254
Contributor

erm14254 commented Mar 11, 2026

rather than the trials as they scroll by? Because only the results determine which is better.

I am comparing the refusals-to-KL-divergence ratio and how many results become unusable due to insanely high KL divergence. For example, something like KL divergence: 31.1086 is completely unusable, and results like that are introduced by the new expanded ranges. For some models, like the Qwen3.5 family or GLM 4.7 Flash, the new expanded ranges are necessary to reach low refusals with ARA at all; however, for the Gemma 3 or GPT-OSS families, the old ranges already give great, fast results with a very good refusals-to-KL-divergence ratio.

For some models the new expanded range are a "necessary evil" for now as it's the only way to reach low refusals, but at the same time it's very far from perfect and can give you ridiculous results such as this:

Running trial 227 of 400...

  • Parameters:
    • start_layer_index = 0
    • end_layer_index = 38
    • preserve_good_behavior_weight = 0.9543
    • steer_bad_behavior_weight = 0.1140
    • overcorrect_relative_weight = 1.1662
    • neighbor_count = 2
  • Reloading model...
    Loading weights: 100%|███████████████████████████████████████████████████████████████| 751/751 [00:17<00:00, 42.32it/s]
  • Abliterating (Arbitrary-Rank Ablation)...
  • Evaluating...
    • Obtaining first-token probability distributions...
    • KL divergence: 22.8911
    • Counting model refusals...
    • Refusals: 100/100

@p-e-w
Owner Author

p-e-w commented Mar 11, 2026

and how many results become unusable due to insanely high KL divergence

That's irrelevant. We only need one good result in the end. The high-KLD trials are NOT wasted, they are TPE exploring the parameter space, informing where to sample from next.

and can give you ridiculous results such as this

There is nothing ridiculous about that trial. It's just another step in TPE trying to understand the objective function.

I think you may be misunderstanding what the optimizer does. Black-box optimization doesn't work like gradient descent. There is no expectation that trials get monotonically better as the run proceeds.

There are usually only 2-3 trials in the Pareto front that are good: Those near the first big step down in KLD. Those trials are the only meaningful metric for comparing two methods. According to that metric, are you consistently seeing worse results post-c76416f for any model?
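For anyone following along: "Pareto front" here means the set of trials not dominated on both objectives (refusals and KL divergence, lower is better on each). A minimal sketch of that filter, with hypothetical trial values rather than data from an actual run:

```python
def pareto_front(trials):
    """Return the trials not dominated on both objectives.

    Each trial is a (refusals, kl_divergence) pair; lower is
    better on both axes.
    """

    def dominates(a, b):
        # a dominates b if it is at least as good on both
        # objectives and is not the same point
        return a[0] <= b[0] and a[1] <= b[1] and a != b

    return [t for t in trials if not any(dominates(o, t) for o in trials)]


# Hypothetical trial results (refusals out of 100, KLD):
trials = [(57, 0.0), (6, 0.0039), (0, 0.0288), (100, 22.89)]
print(pareto_front(trials))  # the high-KLD trial is filtered out
```

Among the surviving trials, the ones "near the first big step down in KLD" are the low-refusal, low-KLD points at the knee of that front.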

@erm14254
Copy link
Copy Markdown
Contributor

erm14254 commented Mar 11, 2026

According to that metric, are you consistently seeing worse results post-c76416f for any model?

No, it's not so black and white. Post-c76416f is absolutely needed for certain models that seem harder to crack with ARA. For example, Qwen3.5 27B is one such model: pre-c76416f it was hard to get low refusals at all, and the best possible result with the pre-c76416f parameters would give you maybe 55/100 at best.

However, c76416f is not required for certain models that already give great results on the old settings, without having to spend 20+ hours on 1000+ trials because the search space has been so over-expanded. Gemma 3 is one such model.

@p-e-w
Copy link
Copy Markdown
Owner Author

p-e-w commented Mar 11, 2026

without having to spend 20+ hours on 1000+ trials because the search space has been so over-expanded

There should be no need to run any additional trials post-c76416f. What makes you think that?

In a previous comment you wrote

The ~20x larger search space may need significantly more trials [...]

This is again a misunderstanding of how TPE operates. From the point of view of TPE, the search space is not 20x larger, because TPE doesn't sample uniformly. The GMMs effectively eliminate almost all of the parameter ranges with high probability, once a few exploratory trials have been sampled. Internally, TPE constructs a hypercube that normalizes all ranges, so the absolute width of the ranges is effectively irrelevant. The main constraint governing performance is the number of parameters, not the size of their individual sampling spaces.

Edit: Changed "randomly" to "uniformly". TPE indeed samples randomly, just not from a uniform distribution.
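To illustrate the normalization point: if each parameter range is mapped to the unit interval before the sampler models it, widening a range rescales values but doesn't change the geometry the GMMs see. A toy sketch of that mapping (not Optuna's actual internals):

```python
def to_unit_cube(value, low, high):
    # map a parameter value into [0, 1] relative to its search range
    return (value - low) / (high - low)


# The same relative position in a narrow range and in a much wider
# range normalizes to the same point in the unit hypercube, so the
# sampler's model is indifferent to the absolute range width:
narrow = to_unit_cube(0.75, 0.5, 1.0)   # midpoint of [0.5, 1.0]
wide = to_unit_cube(5.25, 0.5, 10.0)    # midpoint of [0.5, 10.0]
print(narrow, wide)  # both 0.5
```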

@erm14254
Copy link
Copy Markdown
Contributor

erm14254 commented Mar 11, 2026

I understand that TPE isn't uniform random sampling, and I understand the search-space normalization. But from a practical standpoint:

* Gemma-3 with old ranges: 4/100 refusals, 0.013 KL in ~100 trials (~1.5 hours)
* Gemma-3 with new ranges: more or less the same results, but it took 200+ trials to get there

The Pareto front may end up equivalent, but the time-to-result increased. For users running on consumer hardware, where each trial takes 1-10 minutes, that's the difference between a quick evening run and an overnight job.

I'm not saying the expanded ranges are wrong; they're clearly necessary for certain models that are harder to crack for ARA. I'm just noting that for "easier" models, the old ranges got there faster in practice.

@p-e-w
Copy link
Copy Markdown
Owner Author

p-e-w commented Mar 11, 2026

So is your mode of operation to watch the trials and stop the run the moment you see a trial that you consider good enough?

Heretic isn't really designed for that approach, but I understand that in that case, the individual trials matter.

@erm14254
Copy link
Copy Markdown
Contributor

erm14254 commented Mar 11, 2026

Here's some data from GLM-4.7-Flash (MoE, 62.5GB):

With expanded ranges (261 trials), Pareto front:

| Trial | Refusals | KL |
|-------|----------|--------|
| 261   | 6/100    | 0.0039 |
| 260   | 0/100    | 0.0288 |

This confirms that expanded ranges are necessary for some models; the old ranges couldn't crack GLM at all. The overcorrect_relative_weight > 1.0 region is indeed where the good results came from (trials 260-261 are both around 1.0-1.1).

So I agree expanded ranges should stay as the default. My earlier concern was more about time-to-result on easier models that were already easily abliterable on the older ranges, but that's a minor UX issue, not a correctness issue.

So is your mode of operation to watch the trials and stop the run the moment you see a trial that you consider good enough?

More or less, yes.

@p-e-w
Copy link
Copy Markdown
Owner Author

p-e-w commented Mar 11, 2026

I just completed parallel runs (200 trials each) for Gemma 3 4B and Qwen 3 4B to get a deeper understanding of high-KLD trials. About 25% of all trials resulted in KLDs > 3.0.

My speculative theory was something like: "If overcorrect_relative_weight > 1.0 and preserve_good_behavior_weight / steer_bad_behavior_weight < some constant, then the KLD will be very high."

After looking at dozens of high-KLD and low-KLD trials, I can now say with near-certainty that such a relationship does not hold. The overcorrect_relative_weight > 1.0 region appears to be chaotic (possibly even quasi-fractal), and likely influenced by L-BFGS iteration behavior, so there is no obvious way to filter out such trials without running them.

However, and this is the important part: The results (Pareto front) after 200 trials were still excellent for both models, and high-quality parameter combinations with overcorrect_relative_weight > 1.0 were reliably found.

Conclusions:

  1. The current parameter ranges appear to give good results for all models tested so far.
  2. The "cancel early after eyeballing the trial stream" use case is not something I intend to support with extra logic.
  3. In the future, feat: make parameter ranges configurable #138 (and followup changes) will make it possible to change the parameter ranges from the configuration file, for users who want to do that.

@EugeoSynthesisThirtyTwo
Copy link
Copy Markdown

EugeoSynthesisThirtyTwo commented Mar 15, 2026

I encountered a division by zero:

(heretic) C:\Users\me\Docs\IA\heretic>heretic --model C:\Users\me\Docs\Quantization\GLM-4.7-Flash\hf
█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀  v1.2.0
█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░
▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀  https://github.com/p-e-w/heretic

Detected 2 CUDA device(s):
* GPU 0: NVIDIA GeForce RTX 5090
* GPU 1: NVIDIA GeForce RTX 5090

Loading model C:\Users\me\Docs\Quantization\GLM-4.7-Flash\hf...
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:32<00:00, 23.14it/s]
Failed (We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of the above report!)
Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:07<00:00, 101.10it/s]
Ok
* Transformer model with 47 layers
* Abliterable components:
  * attn.o_proj: 1 modules per layer
  * mlp.down_proj: 1 modules per layer

Resident system RAM: 2.71 GB
Allocated GPU VRAM: 27.99 GB
Reserved GPU VRAM: 29.82 GB

Loading good prompts from mlabonne/harmless_alpaca...
README.md: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 388/388 [00:00<00:00, 2.61MB/s]
data/train-00000-of-00001.parquet: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 972k/972k [00:00<00:00, 1.57MB/s]
data/test-00000-of-00001.parquet: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 243k/243k [00:00<00:00, 1.17MB/s]
Generating train split: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 25058/25058 [00:00<00:00, 1622853.63 examples/s]
Generating test split: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 6265/6265 [00:00<00:00, 754834.96 examples/s]
* 400 prompts loaded

Loading bad prompts from mlabonne/harmful_behaviors...
README.md: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 381/381 [00:00<00:00, 1.63MB/s]
data/train-00000-of-00001.parquet: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 15.0k/15.0k [00:00<00:00, 70.4kB/s]
data/test-00000-of-00001.parquet: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 5.46k/5.46k [00:00<00:00, 27.1kB/s]
Generating train split: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 416/416 [00:00<00:00, 50428.63 examples/s]
Generating test split: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [00:00<00:00, 18049.72 examples/s]
* 400 prompts loaded

Determining optimal batch size...
* Trying batch size 1... Ok (13 tokens/s)
* Trying batch size 2... Ok (24 tokens/s)
* Trying batch size 4... Ok (47 tokens/s)
* Trying batch size 8... Ok (89 tokens/s)
* Trying batch size 16... Failed (CUDA out of memory. Tried to allocate 768.00 MiB. GPU 1 has a total capacity of 31.84 GiB of which 0 bytes is free. Of the allocated memory 28.09 GiB is allocated by
PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid
fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables))
* Chosen batch size: 8

Checking for common response prefix...
* None found

Loading good evaluation prompts from mlabonne/harmless_alpaca...
* 100 prompts loaded
* Obtaining first-token probability distributions...

Loading bad evaluation prompts from mlabonne/harmful_behaviors...
* 100 prompts loaded
* Counting model refusals...
* Initial refusals: 57/100

Obtaining module I/O for good prompts...
Obtaining module I/O for bad prompts...

Running trial 1 of 200...
* Parameters:
  * start_layer_index = 0
  * end_layer_index = 27
  * preserve_good_behavior_weight = 0.9375
  * steer_bad_behavior_weight = 0.0097
  * overcorrect_relative_weight = 0.7312
  * neighbor_count = 10
* Reloading model...
Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 751/751 [00:07<00:00, 102.26it/s]
* Abliterating (Arbitrary-Rank Ablation)...
[W 2026-03-15 11:33:19,747] Trial 0 failed with parameters: {'start_layer_index': 0, 'end_layer_index': 27, 'preserve_good_behavior_weight': 0.9374664708416186, 'steer_bad_behavior_weight': 0.009728694745534946, 'overcorrect_relative_weight': 0.731168614452757, 'neighbor_count': 10} because of the following error: ZeroDivisionError('float division by zero').
Traceback (most recent call last):
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\_optimize.py", line 206, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\me\Docs\IA\heretic\src\heretic\main.py", line 617, in objective_wrapper
    return objective(trial)
  File "C:\Users\me\Docs\IA\heretic\src\heretic\main.py", line 589, in objective
    model.ara_abliterate(good_module_io, bad_module_io, ara_parameters)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\me\Docs\IA\heretic\src\heretic\model.py", line 631, in ara_abliterate
    loss = optimizer.step(closure)
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\optimizer.py", line 526, in wrapper
    out = func(*args, **kwargs)
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\_dynamo\eval_frame.py", line 1181, in _fn
    return fn(*args, **kwargs)
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\utils\_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py", line 478, in step
    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                                        ~~~~~~~~~~~~~^
        obj_func,
        ^^^^^^^^^
    ...<6 lines>...
        max_ls=max_eval - current_evals,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py", line 114, in _strong_wolfe
    t = _cubic_interpolate(
        # pyrefly: ignore [index-error]
    ...<11 lines>...
        bracket_gtd[1],
    )
  File "C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py", line 27, in _cubic_interpolate
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
                   ~~~~~~~~~~~~~~^~~~~~~~~~~
ZeroDivisionError: float division by zero
[W 2026-03-15 11:33:19,750] Trial 0 failed with value None.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in _run_module_as_main:198                                                                       │
│ in _run_code:88                                                                                  │
│                                                                                                  │
│ in <module>:5                                                                                    │
│                                                                                                  │
│   2 from heretic.main import main                                                                │
│   3 if __name__ == '__main__':                                                                   │
│   4 │   sys.argv[0] = sys.argv[0].removesuffix('.exe')                                           │
│ ❱ 5 │   sys.exit(main())                                                                         │
│   6                                                                                              │
│                                                                                                  │
│ C:\Users\me\Docs\IA\heretic\src\heretic\main.py:977 in main                              │
│                                                                                                  │
│   974 │   install()                                                                              │
│   975 │                                                                                          │
│   976 │   try:                                                                                   │
│ ❱ 977 │   │   run()                                                                              │
│   978 │   except BaseException as error:                                                         │
│   979 │   │   # Transformers appears to handle KeyboardInterrupt (or BaseException)              │
│   980 │   │   # internally in some places, which can re-raise a different error in the handler   │
│                                                                                                  │
│ C:\Users\me\Docs\IA\heretic\src\heretic\main.py:648 in run                               │
│                                                                                                  │
│   645 │   │   print("Resuming existing study.")                                                  │
│   646 │                                                                                          │
│   647 │   try:                                                                                   │
│ ❱ 648 │   │   study.optimize(                                                                    │
│   649 │   │   │   objective_wrapper,                                                             │
│   650 │   │   │   n_trials=settings.n_trials - count_completed_trials(),                         │
│   651 │   │   )                                                                                  │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\study.py:490 in       │
│ optimize                                                                                         │
│                                                                                                  │
│    487 │   │   │   RuntimeError:                                                                 │
│    488 │   │   │   │   If nested invocation of this method occurs.                               │
│    489 │   │   """                                                                               │
│ ❱  490 │   │   _optimize(                                                                        │
│    491 │   │   │   study=self,                                                                   │
│    492 │   │   │   func=func,                                                                    │
│    493 │   │   │   n_trials=n_trials,                                                            │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\_optimize.py:68 in    │
│ _optimize                                                                                        │
│                                                                                                  │
│    65 │                                                                                          │
│    66 │   try:                                                                                   │
│    67 │   │   if n_jobs == 1:                                                                    │
│ ❱  68 │   │   │   _optimize_sequential(                                                          │
│    69 │   │   │   │   study,                                                                     │
│    70 │   │   │   │   func,                                                                      │
│    71 │   │   │   │   n_trials,                                                                  │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\_optimize.py:165 in   │
│ _optimize_sequential                                                                             │
│                                                                                                  │
│   162 │   │   │   │   break                                                                      │
│   163 │   │                                                                                      │
│   164 │   │   try:                                                                               │
│ ❱ 165 │   │   │   frozen_trial_id = _run_trial(study, func, catch)                               │
│   166 │   │   finally:                                                                           │
│   167 │   │   │   # The following line mitigates memory problems that can be occurred in some    │
│   168 │   │   │   # environments (e.g., services that use computing containers such as GitHub    │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\_optimize.py:263 in   │
│ _run_trial                                                                                       │
│                                                                                                  │
│   260 │   │   and func_err is not None                                                           │
│   261 │   │   and not isinstance(func_err, catch)                                                │
│   262 │   ):                                                                                     │
│ ❱ 263 │   │   raise func_err                                                                     │
│   264 │   return trial._trial_id                                                                 │
│   265                                                                                            │
│   266                                                                                            │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\optuna\study\_optimize.py:206 in   │
│ _run_trial                                                                                       │
│                                                                                                  │
│   203 │                                                                                          │
│   204 │   with get_heartbeat_thread(trial._trial_id, study._storage):                            │
│   205 │   │   try:                                                                               │
│ ❱ 206 │   │   │   value_or_values = func(trial)                                                  │
│   207 │   │   except exceptions.TrialPruned as e:                                                │
│   208 │   │   │   # TODO(mamu): Handle multi-objective cases.                                    │
│   209 │   │   │   state = TrialState.PRUNED                                                      │
│                                                                                                  │
│ C:\Users\me\Docs\IA\heretic\src\heretic\main.py:617 in objective_wrapper                 │
│                                                                                                  │
│   614 │                                                                                          │
│   615 │   def objective_wrapper(trial: Trial) -> tuple[float, float]:                            │
│   616 │   │   try:                                                                               │
│ ❱ 617 │   │   │   return objective(trial)                                                        │
│   618 │   │   except KeyboardInterrupt:                                                          │
│   619 │   │   │   # Stop the study gracefully on Ctrl+C.                                         │
│   620 │   │   │   trial.study.stop()                                                             │
│                                                                                                  │
│ C:\Users\me\Docs\IA\heretic\src\heretic\main.py:589 in objective                         │
│                                                                                                  │
│   586 │   │   │   print("* Reloading model...")                                                  │
│   587 │   │   │   model.reset_model()                                                            │
│   588 │   │   │   print("* Abliterating (Arbitrary-Rank Ablation)...")                           │
│ ❱ 589 │   │   │   model.ara_abliterate(good_module_io, bad_module_io, ara_parameters)            │
│   590 │   │   else:                                                                              │
│   591 │   │   │   print("* Resetting model...")                                                  │
│   592 │   │   │   model.reset_model()                                                            │
│                                                                                                  │
│ C:\Users\me\Docs\IA\heretic\src\heretic\model.py:631 in ara_abliterate                   │
│                                                                                                  │
│   628 │   │   │   │   │                                                                          │
│   629 │   │   │   │   │   # Convergence usually happens within 2-3 steps, so this is more than   │
│   630 │   │   │   │   │   for step in range(5):                                                  │
│ ❱ 631 │   │   │   │   │   │   loss = optimizer.step(closure)                                     │
│   632 │   │   │   │   │   │   # print(                                                           │
│   633 │   │   │   │   │   │   #    f"\\[{layer_index}/{component}/{module_index}] Step: {step}   │
│   634 │   │   │   │   │   │   # )                                                                │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\optimizer.py:526 in    │
│ wrapper                                                                                          │
│                                                                                                  │
│    523 │   │   │   │   │   │   │   )                                                             │
│    524 │   │   │   │                                                                             │
│    525 │   │   │   │   # pyrefly: ignore [invalid-param-spec]                                    │
│ ❱  526 │   │   │   │   out = func(*args, **kwargs)                                               │
│    527 │   │   │   │   self._optimizer_step_code()                                               │
│    528 │   │   │   │                                                                             │
│    529 │   │   │   │   # call optimizer step post hooks                                          │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\_dynamo\eval_frame.py:1181   │
│ in _fn                                                                                           │
│                                                                                                  │
│   1178 │   │   │   │   │   │   │   }                                                             │
│   1179 │   │   │   │   │   │   ):                                                                │
│   1180 │   │   │   │   │   │   │   return fn(*args, **kwargs)                                    │
│ ❱ 1181 │   │   │   │   │   return fn(*args, **kwargs)                                            │
│   1182 │   │   │   │   finally:                                                                  │
│   1183 │   │   │   │   │   set_eval_frame(None)                                                  │
│   1184 │   │   │   finally:                                                                      │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\utils\_contextlib.py:124 in  │
│ decorate_context                                                                                 │
│                                                                                                  │
│   121 │   def decorate_context(*args, **kwargs):                                                 │
│   122 │   │   # pyrefly: ignore [bad-context-manager]                                            │
│   123 │   │   with ctx_factory():                                                                │
│ ❱ 124 │   │   │   return func(*args, **kwargs)                                                   │
│   125 │                                                                                          │
│   126 │   return decorate_context                                                                │
│   127                                                                                            │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py:478 in step   │
│                                                                                                  │
│   475 │   │   │   │   │   def obj_func(x, t, d):                                                 │
│   476 │   │   │   │   │   │   return self._directional_evaluate(closure, x, t, d)                │
│   477 │   │   │   │   │                                                                          │
│ ❱ 478 │   │   │   │   │   loss, flat_grad, t, ls_func_evals = _strong_wolfe(                     │
│   479 │   │   │   │   │   │   obj_func,                                                          │
│   480 │   │   │   │   │   │   x_init,                                                            │
│   481 │   │   │   │   │   │   t,                                                                 │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py:114 in        │
│ _strong_wolfe                                                                                    │
│                                                                                                  │
│   111 │   │   │   break                                                                          │
│   112 │   │                                                                                      │
│   113 │   │   # compute new trial value                                                          │
│ ❱ 114 │   │   t = _cubic_interpolate(                                                            │
│   115 │   │   │   # pyrefly: ignore [index-error]                                                │
│   116 │   │   │   # pyrefly: ignore [unbound-name]                                               │
│   117 │   │   │   bracket[0],                                                                    │
│                                                                                                  │
│ C:\Users\me\miniconda3\envs\heretic\Lib\site-packages\torch\optim\lbfgs.py:27 in         │
│ _cubic_interpolate                                                                               │
│                                                                                                  │
│    24 │   #   d2 = sqrt(d1^2 - g1*g2);                                                           │
│    25 │   #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));                        │
│    26 │   #   t_new = min(max(min_pos,xmin_bound),xmax_bound);                                   │
│ ❱  27 │   d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)                                               │
│    28 │   d2_square = d1**2 - g1 * g2                                                            │
│    29 │   if d2_square >= 0:                                                                     │
│    30 │   │   d2 = d2_square.sqrt()                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ZeroDivisionError: float division by zero

@p-e-w
Copy link
Copy Markdown
Owner Author

p-e-w commented Mar 15, 2026

@EugeoSynthesisThirtyTwo

Thanks for letting me know. The objective has small gradient discontinuities because of the top-k neighbor selection, which might be the root cause of this problem.
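For illustration, here's why a top-k term is awkward for L-BFGS, whose strong-Wolfe line search assumes a smooth objective: the loss stays continuous, but its slope jumps wherever the set of selected elements changes. A minimal pure-Python sketch of that kink (not the actual ARA objective):

```python
def topk_sum(values, k):
    # sum of the k largest entries; a stand-in for top-k neighbor selection
    return sum(sorted(values, reverse=True)[:k])


# Sweep one input past a competitor. The value of f is continuous,
# but its slope with respect to t jumps from 0 to 1 at t = 0.5,
# the point where t enters the selected top-2 set.
def f(t):
    return topk_sum([t, 1.0, 0.5], k=2)


print(f(0.25))  # 1.5  (t not selected; local slope 0)
print(f(0.75))  # 1.75 (t selected; local slope 1)
```

The cubic interpolation in torch's line search divides by the distance between the two bracket points (the `(x1 - x2)` visible in the traceback above), so a kinked objective that makes the bracket degenerate to a single point produces exactly this ZeroDivisionError.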

@p-e-w
Copy link
Copy Markdown
Owner Author

p-e-w commented Mar 16, 2026

UGI results for the first ARA model are in. They're promising, though not quite as good as I had hoped:

ara_ugi

 

The difference between Derestricted, SOMA, and ARA is basically noise considering the range of the rest of the leaderboard, but MuXodious/gpt-oss-20b-RichardErkhov-heresy is clearly much better than any of them.

So ARA appears to work well, especially considering that it operates completely differently from any other method on that list, but it's not quite where I want it to be yet, which is at the number 1 spot. I'm going to try adding magnitude preservation to the optimizer (mimicking MPOA, which has been suggested by several people), and see whether it improves the result.

Interestingly, no uncensoring method so far has even come close to preserving gpt-oss-20b's NatInt score of 27.18.

@erm14254
Contributor

I'm going to try adding magnitude preservation to the optimizer (mimicking MPOA, which has been suggested by several people), and see whether it improves the result.

More ARA models have been rated in the UGI. There is a pattern where models abliterated with MPOA generally retain more quality than ARA models, even at higher KL divergence.

@p-e-w
Owner Author

p-e-w commented Mar 19, 2026

More ARA models have been rated in the UGI. There is a pattern where models abliterated with MPOA generally retain more quality than ARA models, even at higher KL divergence.

Not sure if this is accurate. ArliAI/gpt-oss-20b-Derestricted (which uses MPOA) actually has a substantially worse NatInt score than p-e-w/gpt-oss-20b-heretic-ara-v3.

I've been researching how we can predict model quality more reliably than using the KLD alone. Please see #236 for initial results.

@erm14254
Contributor

erm14254 commented Mar 19, 2026

Not sure if this is accurate. ArliAI/gpt-oss-20b-Derestricted (which uses MPOA) actually has a substantially worse NatInt score than p-e-w/gpt-oss-20b-heretic-ara-v3.

I said "in general"; it's not a clear-cut rule, and I look at more than just Native Intelligence. Also, if you look closely, some abliterated models scored quite a bit higher on Native Intelligence post-abliteration than the vanilla baseline did pre-abliteration.

I've been researching how we can predict model quality more reliably than using the KLD alone. Please see #236 for initial results.

Thanks, I'll look into it.

@p-e-w
Owner Author

p-e-w commented Mar 19, 2026

some abliterated models scored quite a bit higher on Native Intelligence post-abliteration than the vanilla baseline did pre-abliteration.

This has never been achieved with gpt-oss-20b (very far from it, actually), which is why it is my model of choice for experiments. Whatever works there will likely work elsewhere as well.

@p-e-w
Owner Author

p-e-w commented Mar 31, 2026

The ARA branch now supports row-norm preservation during optimization using a reparameterization constraint, as suggested by @spikymoth and others. It will be interesting to see whether this improves benchmark scores.
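As a rough illustration of the idea (the class name is hypothetical; the actual implementation in the branch may differ), the weight matrix can be parameterized through an unconstrained matrix whose rows are rescaled to the original row norms, so that the optimizer cannot change them by construction:

```python
import torch

class RowNormPreserving(torch.nn.Module):
    """Reparameterize a weight matrix so each row keeps a fixed norm.

    The optimizer updates the unconstrained parameter `raw`; the forward
    pass rescales every row of `raw` to the row norms of the original
    matrix, so row norms are preserved no matter what steps are taken.
    """

    def __init__(self, original: torch.Tensor):
        super().__init__()
        self.raw = torch.nn.Parameter(original.clone())
        # Fixed targets: the original per-row L2 norms.
        self.register_buffer("target_norms",
                             original.norm(dim=1, keepdim=True))

    def forward(self) -> torch.Tensor:
        current = self.raw.norm(dim=1, keepdim=True).clamp_min(1e-12)
        return self.raw * (self.target_norms / current)

W = torch.randn(4, 8)
module = RowNormPreserving(W)
with torch.no_grad():
    module.raw.add_(torch.randn_like(module.raw))  # arbitrary "optimizer step"
# Row norms survive the update unchanged.
assert torch.allclose(module().norm(dim=1), W.norm(dim=1), atol=1e-5)
```

The same effect could also be achieved with `torch.nn.utils.parametrize` on an existing linear layer; the standalone module above is just the minimal form of the constraint.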
