remove str option for quantization config in torchao #13291

Open
howardzhang-cv wants to merge 4 commits into huggingface:main from howardzhang-cv:update_torchao_test

Conversation

@howardzhang-cv
Contributor

What does this PR do?

Remove the deprecated string-based quant_type path from TorchAoConfig, requiring AOBaseConfig instances instead.

  • `TorchAoConfig.__init__` now only accepts AOBaseConfig subclass instances (e.g. Int8WeightOnlyConfig()) and raises TypeError for strings
  • Deleted ~200 lines of dead code: _get_torchao_quant_type_to_method, _is_xpu_or_cuda_capability_atleast_8_9, TorchAoJSONEncoder, and all string-parsing branches in post_init, to_dict, from_dict, and get_apply_tensor_subclass
  • Simplified torchao_quantizer.py: removed string-based branches in update_torch_dtype, adjust_target_dtype, and get_cuda_warm_up_factor; fixed is_trainable, which would crash on AOBaseConfig objects
  • Converted all test cases from string quant types to their AOBaseConfig equivalents; removed test_floatx_quantization (no replacement for fpx_weight_only)
  • Updated docs to show only AOBaseConfig-based usage
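The stricter constructor behavior described above can be illustrated with a minimal standalone sketch. AOBaseConfig, Int8WeightOnlyConfig, and validate_quant_type below are stand-ins for the real torchao classes and for the check inside diffusers' TorchAoConfig, not the actual implementation.

```python
# Minimal sketch of the new validation: only AOBaseConfig instances pass,
# strings raise TypeError. All names here are illustrative stand-ins.
class AOBaseConfig:
    pass

class Int8WeightOnlyConfig(AOBaseConfig):
    pass

def validate_quant_type(quant_type):
    if not isinstance(quant_type, AOBaseConfig):
        raise TypeError(
            f"quant_type must be an AOBaseConfig instance, got {type(quant_type).__name__}"
        )
    return quant_type

validate_quant_type(Int8WeightOnlyConfig())   # accepted
try:
    validate_quant_type("int8_weight_only")   # old string path, now rejected
except TypeError as exc:
    print(f"rejected: {exc}")
```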

Testing

python -m pytest tests/quantization/torchao/test_torchao.py -xvs

Who can review?

@sayakpaul

Comment on lines +217 to +227
if isinstance(quant_type, AOBaseConfig):
    # Extract size digit using fuzzy match on the class name
    config_name = quant_type.__class__.__name__
    size_digit = fuzzy_match_size(config_name)

    # Map the extracted digit to appropriate dtype
    if size_digit == "4":
        return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
    return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
    return {
        1: torch.uint1,
        2: torch.uint2,
        3: torch.uint3,
        4: torch.uint4,
        5: torch.uint5,
        6: torch.uint6,
        7: torch.uint7,
    }[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
    return torch.bfloat16

elif is_torchao_version(">", "0.9.0"):
    from torchao.core.config import AOBaseConfig

    quant_type = self.quantization_config.quant_type
    if isinstance(quant_type, AOBaseConfig):
        # Extract size digit using fuzzy match on the class name
        config_name = quant_type.__class__.__name__
        size_digit = fuzzy_match_size(config_name)

        # Map the extracted digit to appropriate dtype
        if size_digit == "4":
            return CustomDtype.INT4
        else:
            # Default to int8
            return torch.int8
    else:
        # Default to int8
        return torch.int8
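For context on the snippet above, a helper along these lines could implement the fuzzy match on the class name. This is a hypothetical sketch; the real fuzzy_match_size in diffusers may use a different pattern.

```python
import re
from typing import Optional

# Hypothetical sketch: pull the weight bit-width digit out of a torchao
# config class name such as "Int4WeightOnlyConfig". Matching the digit that
# directly precedes "weight" avoids grabbing the activation width in names
# like "Int8DynamicActivationInt4WeightConfig".
def fuzzy_match_size(config_name: str) -> Optional[str]:
    match = re.search(r"(\d+)weight", config_name.lower())
    return match.group(1) if match else None

print(fuzzy_match_size("Int4WeightOnlyConfig"))                   # 4
print(fuzzy_match_size("Int8DynamicActivationInt4WeightConfig"))  # 4
```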


This seems a bit fragile. I think it's from transformers originally; not sure if it's still needed in transformers though, it may have been refactored after the 5.0 update.

Contributor Author


Sorry, this might be a dumb question, but what part are you referring to that's from transformers originally? Is it the entire adjust_target_dtype function?

    f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
    f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):

@jerryzh168 Mar 19, 2026


separate PR: I feel we should just have a single assertion for torchao to be a relatively recent version (e.g. 0.15) and remove all these version checks
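The single version gate suggested here could be sketched as follows. This is a hypothetical helper with naive semver parsing, and the 0.15.0 floor is just the example from this thread; real code would likely reuse diffusers' existing is_torchao_version utility.

```python
# Hypothetical single minimum-version check replacing scattered
# is_torchao_version branches. Naive dotted-version compare; pre-release
# suffixes are not handled.
def _as_tuple(v: str):
    return tuple(int(p) for p in v.split("."))

def check_min_version(installed: str, minimum: str = "0.15.0") -> bool:
    return _as_tuple(installed) >= _as_tuple(minimum)

print(check_min_version("0.15.1"))  # True
print(check_min_version("0.9.0"))   # False
```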

Contributor Author


Actually, yeah, I was going to ask you about that as well. There are a couple of version checks scattered around right now; it would be cleaner to just remove all of them.

Member


Yeah we should mandate a minimum version requirement here.

Contributor Author


Do we want to do this in a separate PR? I set it to 0.9.0 because that's when AOBaseConfig was first supported. Moving to 0.15.0 in a separate PR might be cleaner in case we need to revert for whatever reason.

@howardzhang-cv howardzhang-cv marked this pull request as ready for review March 19, 2026 21:43
Member

@sayakpaul left a comment


Thanks a lot for starting this work!

I think we can merge this fairly soon.


Comment on lines +478 to +479
if not isinstance(self.quant_type, AOBaseConfig):
    raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
Member


Yes cool!

Comment on lines +488 to +490
# For now we assume there is 1 config per Transformer; however, in the future
# we may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
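The serialized shape discussed here (one config per model for now, keyed by "default") can be sketched standalone. config_to_dict and Int8WeightOnlyConfig below are stand-ins for torchao's real serialization helper and config class.

```python
import json

# Standalone sketch of the quant_type serialization shape: a mapping from a
# module-group key ("default" for the whole model) to a serialized config.
# These classes/helpers are illustrative stand-ins, not the torchao API.
class Int8WeightOnlyConfig:
    def __init__(self, group_size=None):
        self.group_size = group_size

def config_to_dict(config) -> dict:
    return {"_type": type(config).__name__, **vars(config)}

d = {"quant_type": {"default": config_to_dict(Int8WeightOnlyConfig())}}
print(json.dumps(d))
```

A future per-fqn scheme would simply add more keys next to "default", one per fully qualified module name.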
Member


This would be really nice! I think we should also provide a reference link to the TorchAO docs to remind ourselves what that granularity means.

Contributor Author


added

@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Mar 20, 2026

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@sayakpaul left a comment


Thanks for working on this. I will run the tests on my end to make sure we didn't miss anything.

We need to update the test suite here as well:

class TorchAoConfigMixin:


Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
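The "static weights, dynamic activations" idea in that paragraph can be shown with a toy numeric sketch in pure Python (no torchao; per-tensor scales for brevity, whereas real schemes typically use per-channel or per-row scales).

```python
# Toy sketch: weights are quantized to int8 once, offline; activations are
# quantized on the fly for each input, then the int8 dot product is rescaled.
def quantize_int8(values):
    scale = max(abs(v) for v in values) / 127 or 1.0
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale

weights = [0.5, -1.0, 0.25]      # quantized ahead of time
w_q, w_scale = quantize_int8(weights)

activations = [2.0, -0.5, 1.5]   # quantized per input at runtime
a_q, a_scale = quantize_int8(activations)

# int8 dot product, rescaled back to float; close to the exact 1.875
dot = sum(w * a for w, a in zip(w_q, a_q)) * w_scale * a_scale
print(round(dot, 3))
```

Only int8 values are stored and multiplied, which is where the memory and compute savings come from; the small error versus 1.875 is the quality tradeoff the docs mention.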

The quantization methods supported are as follows:
Member


Maybe we should add a note about the different Config classes that are supported here? And linking to the TorchAO docs?

logger.warning(
    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
Member


Should we not implement an equivalent of

        if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
                    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
                )

?
