Commit

fix quark fp8 loading
fxmarty-amd committed Jan 31, 2025
1 parent 6852819 commit a38d96a
Showing 1 changed file with 8 additions and 8 deletions.
@@ -33,21 +33,21 @@ def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
             if current_platform.is_rocm():
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=layer.input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
 
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
+
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
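
For readability, here is a sketch of how the per-tensor branch of process_weights_after_loading reads once this change is applied, assembled from the hunk above. The import paths, the module-level def, and the non-ROCm fallback marked below are assumptions added so the sketch is self-contained; they are not part of this diff. The point of the change is the order of operations on ROCm: normalize the checkpoint's e4m3fn weight and scales to e4m3fnuz first, working from the original layer.weight and layer.weight_scale, and only then requantize the fused shards down to a single max scale.

# Sketch only: reconstructed from the hunk above. The helper import paths are
# assumed (vLLM's w8a8 utils around the time of this commit), and the non-ROCm
# fallback marked below is hypothetical, not part of the diff.
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.platforms import current_platform


def process_weights_after_loading(self, layer) -> None:
    # If per tensor, when we have a fused module (e.g. QKV) with per
    # tensor scales (thus N scales being passed to the kernel),
    # requantize so we can always run per tensor
    if self.qscheme == "per_tensor":
        if current_platform.is_rocm():
            # Convert the checkpoint's e4m3fn weight and scales to e4m3fnuz
            # *before* requantizing, starting from the original
            # layer.weight / layer.weight_scale (this ordering is the fix).
            weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale)
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale,
                                              requires_grad=False)
        else:
            # Hypothetical fallback, not in the diff: feed the original
            # tensors straight into requantization on non-ROCm platforms.
            weight = layer.weight
            max_w_scale = layer.weight_scale

        # Collapse the per-shard scales of the fused module into a single
        # per-tensor max scale, now operating on the normalized tensors.
        max_w_scale, weight = requantize_with_max_scale(
            weight=weight,
            weight_scale=max_w_scale,
            logical_widths=layer.logical_widths,
        )

        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

Presumably this ordering matters because the max-scale requantization should see the e4m3fnuz tensors and scales that the ROCm kernels will actually consume, rather than requantizing against the original e4m3fn scales and then normalizing the already-requantized result.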
