"""Dequantize mtp.fc from GPTQ INT4 back to bf16 so vLLM's MTP loader picks it up."""
import json, shutil
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Checkpoint directory, the shard holding the extra (MTP) tensors, and the
# shard index that maps tensor names to files.
BASE = Path("Qwen3.6-27B-int4-AutoRound")
EXTRA = BASE / "model_extra_tensors.safetensors"
INDEX = BASE / "model.safetensors.index.json"

# Read every tensor of the extra shard, plus its metadata, so the whole shard
# can be written back later with mtp.fc replaced.
with safe_open(EXTRA, framework="pt") as f:
    meta = f.metadata() or {}
    tensors = {k: f.get_tensor(k) for k in f.keys()}

# Fail with an explanatory message instead of a bare KeyError, e.g. when the
# script is run a second time on a shard that no longer has the packed tensors.
_missing = [
    k
    for k in ("mtp.fc.qweight", "mtp.fc.qzeros", "mtp.fc.scales")
    if k not in tensors
]
if _missing:
    raise KeyError(
        f"{EXTRA} is missing {_missing}; has mtp.fc already been dequantized?"
    )

qw = tensors["mtp.fc.qweight"]
qz = tensors["mtp.fc.qzeros"]
sc = tensors["mtp.fc.scales"]

# qweight holds eight 4-bit values per int32 along dim 0, so the unpacked
# matrix is (in_features, out_features).
in_features = qw.shape[0] * 8
out_features = qw.shape[1]

# The scales tensor has one row per quantization group, so the group layout is
# derived from it rather than hard-coded (a group size of 128 gives the same
# numbers as before).
num_groups = sc.shape[0]
if num_groups == 0 or in_features % num_groups:
    raise ValueError(
        f"in_features={in_features} is not divisible into {num_groups} groups"
    )
group_size = in_features // num_groups
def unpack_int32_4bit(packed: torch.Tensor, axis: int, factor: int = 8) -> torch.Tensor:
    """Expand bit-packed 32-bit words into one small integer per element.

    Each word in ``packed`` holds ``factor`` equal-width fields, least
    significant field first (the default ``factor=8`` means eight 4-bit
    fields). The fields of word ``w`` along ``axis`` land at positions
    ``w * factor`` .. ``w * factor + factor - 1`` of the result.

    Args:
        packed: Tensor of packed words.
        axis: Dimension along which the words were packed; negative values
            count from the last dimension.
        factor: Number of fields per 32-bit word. Must divide 32 and leave
            fields narrow enough for the int8 result.

    Returns:
        An int8 tensor shaped like ``packed`` except that ``axis`` is
        ``factor`` times longer.

    Raises:
        IndexError: If ``axis`` is out of range for ``packed``.
        ValueError: If ``factor`` does not describe a usable field width.
    """
    if factor <= 0 or 32 % factor:
        raise ValueError(f"factor must be a positive divisor of 32, got {factor}")
    bits = 32 // factor
    if bits > 7:
        raise ValueError(f"{bits}-bit fields do not fit in the int8 result")

    # Normalise negative axes up front: the arithmetic below uses ``axis + 1``,
    # which would otherwise point at the wrong dimension and scramble the data
    # without raising.
    if not -packed.ndim <= axis < packed.ndim:
        raise IndexError(f"axis {axis} is out of range for a {packed.ndim}-D tensor")
    if axis < 0:
        axis += packed.ndim

    shifts = torch.arange(0, 32, bits, device=packed.device, dtype=torch.int32)
    # Put the shift amounts on a new dimension right after ``axis`` so that the
    # final reshape interleaves each word's fields in place.
    bcast_shape = [1] * (packed.ndim + 1)
    bcast_shape[axis + 1] = factor
    # The mask also discards sign bits smeared in by the arithmetic shift.
    fields = (packed.unsqueeze(axis + 1) >> shifts.view(bcast_shape)) & ((1 << bits) - 1)

    out_shape = list(packed.shape)
    out_shape[axis] *= factor
    return fields.reshape(out_shape).to(torch.int8)
# Unpack the 4-bit codes: weights are packed along dim 0 (input features),
# zero-points along dim 1 (output features).
w_int = unpack_int32_4bit(qw, axis=0)
z_int = unpack_int32_4bit(qz, axis=1)

# Dequantize group by group in float32: (code - zero_point) * scale, with the
# per-group zero-points and scales broadcast over the rows of their group.
# NOTE(review): some GPTQ-format exporters store qzeros offset by one
# (zero_point - 1); the values are applied as stored here. Confirm against
# the tool that produced this checkpoint.
w_grouped = w_int.view(num_groups, group_size, out_features).to(torch.float32)
w_fp32 = (w_grouped - z_int.unsqueeze(1).to(torch.float32)) * sc.unsqueeze(1).to(torch.float32)

# Transpose from (in_features, out_features) to (out_features, in_features)
# and store as bf16.
w_final = w_fp32.view(in_features, out_features).t().contiguous().to(torch.bfloat16)

# Swap the three packed tensors for the single dense weight.
for k in ("mtp.fc.qweight", "mtp.fc.qzeros", "mtp.fc.scales"):
    del tensors[k]
tensors["mtp.fc.weight"] = w_final

# Write to a sibling temporary file and rename it over the original, so an
# interrupted write cannot leave the only copy of this shard truncated. This
# also avoids truncating the source file while the loaded tensors may still be
# backed by it (presumably memory-mapped; not confirmed here).
tmp_path = EXTRA.with_name(EXTRA.name + ".tmp")
try:
    save_file(tensors, str(tmp_path), metadata=meta)
    tmp_path.replace(EXTRA)
finally:
    # No-op after a successful rename; removes the partial file on failure.
    tmp_path.unlink(missing_ok=True)
from collections import defaultdict


def _shard_payload_bytes(path: Path) -> int:
    """Return the total tensor byte count recorded in a safetensors header.

    The file starts with an 8-byte little-endian header length followed by a
    JSON header; every tensor entry carries ``data_offsets = [begin, end]``.
    Summing ``end - begin`` gives the same total as loading each tensor and
    multiplying element count by element size, without reading tensor data.
    """
    with open(path, "rb") as fh:
        header_len = int.from_bytes(fh.read(8), "little")
        header = json.loads(fh.read(header_len))
    return sum(
        entry["data_offsets"][1] - entry["data_offsets"][0]
        for name, entry in header.items()
        if name != "__metadata__"  # file-level metadata, not a tensor
    )


# Point the index at the dense weight instead of the three packed tensors.
idx = json.loads(INDEX.read_text())
for k in ("mtp.fc.qweight", "mtp.fc.qzeros", "mtp.fc.scales"):
    idx["weight_map"].pop(k, None)
idx["weight_map"]["mtp.fc.weight"] = EXTRA.name

# Recompute total_size over every shard the index references.
shard_sizes = defaultdict(int)
for sf in set(idx["weight_map"].values()):
    shard_sizes[sf] += _shard_payload_bytes(BASE / sf)

# Tolerate an index file that has no "metadata" object yet.
idx.setdefault("metadata", {})["total_size"] = sum(shard_sizes.values())
INDEX.write_text(json.dumps(idx, indent=2))