Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam
from olive.passes.pass_config import BasePassConfig
from olive.search.search_parameter import Boolean, Categorical

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,11 +84,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"int4_block_size": PassConfigParam(
type_=ModelBuilder.BlockSize,
required=False,
search_defaults=Categorical(
[ModelBuilder.BlockSize.B32, ModelBuilder.BlockSize.B64, ModelBuilder.BlockSize.B128]
),
description="Specify the block_size for int4 quantization. Acceptable values: 16/32/64/128/256.",
),
"int4_is_symmetric": PassConfigParam(
type_=bool,
required=False,
search_defaults=Boolean(),
description="Specify whether symmetric or asymmetric INT4 quantization needs to be used.",
),
"int4_op_types_to_quantize": PassConfigParam(
Expand All @@ -106,6 +111,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"int4_algo_config": PassConfigParam(
type_=str,
required=False,
search_defaults=Categorical(
[
"default",
"rtn",
"k_quant_mixed",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need enum str for this as well?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's only a single use case for it. We don't really do anything with it except pass it down to the model builder.

"k_quant_last",
]
),
description="Specify the INT4 quantization algorithm to use in GenAI Model Builder",
),
"use_qdq": PassConfigParam(
Expand Down
10 changes: 9 additions & 1 deletion olive/passes/pytorch/selective_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam
from olive.passes.pytorch.train_utils import get_calibration_dataset, kl_div_loss, load_hf_base_model
from olive.search.search_parameter import Categorical

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
Expand Down Expand Up @@ -65,7 +66,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
return {
"algorithm": PassConfigParam(
type_=SelectiveMixedPrecision.Algorithm,
required=True,
required=False,
search_defaults=Categorical(
[
SelectiveMixedPrecision.Algorithm.K_QUANT_DOWN,
SelectiveMixedPrecision.Algorithm.K_QUANT_MIXED,
SelectiveMixedPrecision.Algorithm.K_QUANT_LAST,
]
),
description="The algorithm to use for mixed precision.",
),
"bits": PassConfigParam(
Expand Down
Loading