diff --git a/tests/ops/test_matmul.py b/tests/ops/test_matmul.py
index 96653d1..67d85e7 100644
--- a/tests/ops/test_matmul.py
+++ b/tests/ops/test_matmul.py
@@ -24,15 +24,17 @@ def run_problem(self, m: int, n: int, k: int, groupsize=-1, dev="cuda"):
             print(f"Out of memory, skipping m={m}, n={n}, k={k}")
 
     def test_tiny(self):
-        self.run_problem(21504 * 2, 4096, 21504 * 2, groupsize=-1)
-        self.run_problem(256, 16, 256, groupsize=-1)
-        self.run_problem(256, 16, 512, groupsize=-1)
+        # self.run_problem(21504 * 2, 4096, 21504 * 2, groupsize=-1)
+        # self.run_problem(256, 16, 256, groupsize=-1)
+        # self.run_problem(256, 16, 512, groupsize=-1)
+        self.run_problem(5504, 5504, 5504, groupsize=-1)
+
 
-    def test_llama(self):
-        bsz = 16
-        for _, layers in llama_shapes.items():
-            for layer in layers:
-                self.run_problem(layer[1], bsz, layer[0])
+    # def test_llama(self):
+    #     bsz = 16
+    #     for _, layers in llama_shapes.items():
+    #         for layer in layers:
+    #             self.run_problem(layer[1], bsz, layer[0])
 
 
 if __name__ == "__main__":
diff --git a/tests/ops/test_sbmm.py b/tests/ops/test_sbmm.py
index f84fb20..dc45027 100644
--- a/tests/ops/test_sbmm.py
+++ b/tests/ops/test_sbmm.py
@@ -31,7 +31,7 @@ def run_problem(
         indices = torch.sort(indices)[0]
         x = torch.randn((nr, k), dtype=torch.float16, device=dev)
         weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
-            nr, m, k, groupsize=groupsize, device=dev
+            nm, m, k, groupsize=groupsize, device=dev
         )
         fp16_output = sbmm_16bit_forloop(weight_ref, x, indices, base_weight=None)
         forloop_output = sbmm_4bit_2_4_forloop(
@@ -65,8 +65,9 @@ def run_problem(
             torch.cuda.empty_cache()
 
     def test_tiny(self):
-        self.run_problem("uniform", 10, 5, 256, 256) 
+        self.run_problem("uniform", 10, 5, 256, 256)
         self.run_problem("zipf:1.5", 128, 2, 4096, 12288)
+
     # def test_llama(self):
     #     nrs = [16, 32, 64, 128, 256]
     #     nms = [[2,4,8,16], [2,4,8,16,32], [2,4,8,16,32,64], [2,4,8,16,32,64,128], [2,4,8,16,32,64,128,256]]
diff --git a/tests/ops/test_sbmm_tp.py b/tests/ops/test_sbmm_tp.py
new file mode 100644
index 0000000..bd9061e
--- /dev/null
+++ b/tests/ops/test_sbmm_tp.py
@@ -0,0 +1,86 @@
+import torch
+import unittest
+from triteia.python.ops import (
+    sbmm_16bit_forloop,
+    sbmm_4bit_2_4_forloop,
+    sbmm_4bit_2_4_native,
+    sbmm_4bit_2_4_multilaunch,
+)
+from triteia.python.configs.models.llama import llama_shapes
+from triteia.python.ops.utils.generator import generate_model_distribution
+from triteia.python.ops import gen_batched_sparse_quant4_NT
+
+
+class TestSBMMOp(unittest.TestCase):
+    def run_problem_column(
+        self,
+        distribution: str,
+        nr: int,
+        nm: int,
+        m: int,
+        k: int,
+        tp_size: int,
+        with_base_weight=False,
+        groupsize=-1,
+        dev="cuda",
+    ):
+        try:
+            print(
+                f"Running sbmm problem with nr={nr}, nm={nm}, m={m}, k={k}, distribution={distribution}"
+            )
+            indices = generate_model_distribution(distribution, nr, nm)
+            indices = torch.sort(indices)[0]
+            x = torch.randn((nr, k), dtype=torch.float16, device=dev)
+            ref_weights, qweights, scales, metas = [], [], [], []
+            ref_fp16_outputs = []
+            outputs = []
+            for i in range(tp_size):
+                weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
+                    nm, m, k, groupsize=groupsize, device=dev
+                )
+                ref_weights.append(weight_ref)
+                qweights.append(qweight)
+                scales.append(scale)
+                metas.append(meta)
+
+                fp16_partial_output = sbmm_16bit_forloop(weight_ref, x, indices, base_weight=None)
+                native_partial_output = sbmm_4bit_2_4_native(
+                    qweight, x, meta, scale, indices, base_weight=None
+                )
+                ref_fp16_outputs.append(fp16_partial_output)
+                outputs.append(native_partial_output)
+
+            ref_fp16_final_outputs = torch.cat(ref_fp16_outputs, dim=1)
+            final_outputs = torch.cat(outputs, dim=1)
+
+            stacked_fp16_weights = torch.cat(ref_weights, dim=2)
+            stacked_qweights = torch.cat(qweights, dim=2)
+            stacked_scales = torch.cat(scales, dim=2)
+            stacked_metas = torch.cat(metas, dim=1)
+
+            stacked_fp16_output = sbmm_16bit_forloop(stacked_fp16_weights, x, indices, base_weight=None)
+            stacked_native_output = sbmm_4bit_2_4_native(
+                stacked_qweights, x, stacked_metas, stacked_scales, indices, base_weight=None
+            )
+            self.assertLess(
+                torch.mean(torch.abs(final_outputs - ref_fp16_final_outputs))
+                / torch.mean(torch.abs(ref_fp16_final_outputs)),
+                0.002,
+            )
+            self.assertLess(
+                torch.mean(torch.abs(stacked_fp16_output - ref_fp16_final_outputs))
+                / torch.mean(torch.abs(ref_fp16_final_outputs)),
+                0.002,
+            )
+
+        except torch.cuda.OutOfMemoryError as e:
+            print(f"Out of memory, skipping nr={nr}, nm={nm}, m={m}, k={k}")
+        finally:
+            torch.cuda.empty_cache()
+
+    def test_tiny(self):
+        self.run_problem_column("uniform", 10, 5, 256, 256, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/triteia/python/nn/linear.py b/triteia/python/nn/linear.py
index 4a38041..9f338fc 100644
--- a/triteia/python/nn/linear.py
+++ b/triteia/python/nn/linear.py
@@ -98,8 +98,8 @@ def pack(self, weight, scales, trans=False):
         maxq = 2**4 - 1
         s = scales
         w = weight
-        if self.groupsize != self.k:
-            w = w.reshape((-1, self.groupsize, self.n))
+        if self.groupsize != self.infeatures:
+            w = w.reshape((-1, self.groupsize, self.outfeatures))
             w = w.permute(1, 0, 2)
             w = w.reshape((self.groupsize, -1))
             s = s.reshape((1, -1))
@@ -108,22 +108,22 @@ def pack(self, weight, scales, trans=False):
         w = torch.round(w / s).int()
         w += (maxq + 1) // 2
         w = torch.clamp(w, 0, maxq)
-        if self.groupsize != self.k:
-            w = w.reshape((self.groupsize, -1, self.n))
+        if self.groupsize != self.infeatures:
+            w = w.reshape((self.groupsize, -1, self.outfeatures))
             w = w.permute(1, 0, 2)
-            w = w.reshape((self.k, self.n)).contiguous()
+            w = w.reshape((self.infeatures, self.outfeatures)).contiguous()
             s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
         else:
             s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
         w = mask * w.T
         w, meta = sparse_semi_structured_from_dense_cutlass(w)
         w = w.t()
-        self.k = self.k // 2
+        self.infeatures = self.infeatures // 2
         self.groupsize = self.groupsize // 2
-        s = s.reshape((-1, self.n)).contiguous()
-        w = w.reshape((self.k // tile, tile, self.n // tile, tile))
+        s = s.reshape((-1, self.outfeatures)).contiguous()
+        w = w.reshape((self.infeatures // tile, tile, self.outfeatures // tile, tile))
         w = w.permute((0, 2, 1, 3))
-        w = w.reshape((self.k // tile, self.n * tile))
+        w = w.reshape((self.infeatures // tile, self.outfeatures * tile))
         res = w
         res = res.reshape((-1, perm.numel()))[:, perm].reshape(res.shape)
         q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
diff --git a/triteia/python/ops/utils/generator.py b/triteia/python/ops/utils/generator.py
index a8c08fe..628484e 100644
--- a/triteia/python/ops/utils/generator.py
+++ b/triteia/python/ops/utils/generator.py
@@ -36,11 +36,9 @@ def reshape(w):
     s = s.reshape((-1, m)).contiguous()
     linear = nn.Linear(k, m)
     linear.weight.data = ref
-    layer = sparse_low_precision_linear(m, k, groupsize=groupsize)
+    layer = sparse_low_precision_linear(k, m, groupsize=-1)
     if groupsize == -1:
         groupsize = k
-    layer.k = k
-    layer.n = m
     layer.groupsize = groupsize
     layer.B = torch.empty((k_sp // 16, m * 16 // 8), dtype=torch.int, device=device)
     layer.meta = torch.empty((m, k // 16), dtype=torch.int16, device=device)
diff --git a/triteia/python/utils/compressor.py b/triteia/python/utils/compressor.py
new file mode 100644
index 0000000..77a8249
--- /dev/null
+++ b/triteia/python/utils/compressor.py
@@ -0,0 +1,121 @@
+import torch
+import cupy as cp
+from typing import Dict
+from torch.utils.dlpack import to_dlpack, from_dlpack
+from tqdm import tqdm
+try:
+    import kvikio
+    from kvikio.nvcomp import LZ4Manager
+    from kvikio.nvcomp import SnappyManager
+    from kvikio.nvcomp import BitcompManager
+    from kvikio.nvcomp import GdeflateManager
+    from kvikio.nvcomp import CascadedManager
+except ImportError:
+    raise ImportError(
+        "Please install kvikio to use the LosslessCompressor class. "
+        "Check out `https://github.com/rapidsai/kvikio`."
+    )
+
+dtype_maps = {
+    "int8": torch.int8,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "int32": torch.int32,
+}
+
+cp_dtype_maps = {
+    "int8": cp.int8,
+    "fp16": cp.float16,
+    "fp32": cp.float32,
+    "int32": cp.int32,
+}
+
+
+class LosslessCompressor:
+
+    def __init__(self, algorithm: str = "gdeflate", device_id: int = 0) -> None:
+        if algorithm == "gdeflate":
+            self.comp_manager = GdeflateManager(device_id=device_id)
+        elif algorithm == "lz4":
+            self.comp_manager = LZ4Manager(device_id=device_id)
+        elif algorithm == "snappy":
+            self.comp_manager = SnappyManager(device_id=device_id)
+        elif algorithm == "bitcomp":
+            self.comp_manager = BitcompManager(device_id=device_id)
+        elif algorithm == "cascaded":
+            self.comp_manager = CascadedManager(device_id=device_id)
+        else:
+            raise ValueError(
+                f"Unsupported algorithm: {algorithm}, supported algorithms: ['gdeflate', 'lz4', 'snappy', 'bitcomp', 'cascaded']"
+            )
+
+    def compress_tensor(self, tensor: torch.Tensor):
+        tensor.requires_grad_(False)
+        tensor_shape = tensor.shape
+        if not tensor.is_cuda:
+            tensor = tensor.cuda()
+        to_compress_tensor = cp.from_dlpack(to_dlpack(tensor))
+        # logger.debug(f"compressing dtype {tensor.dtype}")
+        if tensor.dtype == torch.int8:
+            dtype = "int8"
+            self.comp_manager.input_type = cp.int8
+        elif tensor.dtype == torch.float16:
+            dtype = "fp16"
+            self.comp_manager.input_type = cp.float16
+        elif tensor.dtype == torch.int32:
+            dtype = "int32"
+            self.comp_manager.input_type = cp.int32
+        elif tensor.dtype == torch.float32:
+            dtype = "fp32"
+            self.comp_manager.input_type = cp.float32
+        else:
+            raise ValueError(f"Unsupported dtype: {tensor.dtype}")
+        compressed_tensor = self.comp_manager.compress(to_compress_tensor)
+        return cp.asnumpy(compressed_tensor), tensor_shape, dtype
+
+    def decompress_tensor(
+        self,
+        compressed_tensor: cp.array,
+        tensor_shape: tuple,
+        dtype="fp16",
+        target_device="cuda:0",
+    ):
+        self.comp_manager.input_type = cp_dtype_maps[dtype]
+        decompressed_tensor = self.comp_manager.decompress(compressed_tensor)
+        torch_tensor = torch.reshape(
+            from_dlpack(decompressed_tensor.toDlpack()), tensor_shape
+        )
+        return torch_tensor.to(torch.device(target_device))
+
+    def compress_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        tensors = {}
+        tensors_shape = {}
+        tensors_dtype = {}
+        for key in state_dict:
+            tensors[key], tensors_shape[key], tensors_dtype[key] = self.compress_tensor(
+                state_dict[key]
+            )
+        return tensors, tensors_shape, tensors_dtype
+
+    def decompress_state_dict(
+        self,
+        compressed_state_dict: Dict[str, cp.array],
+        tensor_shapes: Dict[str, tuple],
+        tensor_dtypes: Dict[str, str] = None,
+        use_bfloat16: bool = False,
+        target_device: str = "cuda:0",
+    ):
+        with torch.no_grad():
+            tensors = {}
+            for key in tqdm(compressed_state_dict.keys()):
+                decompressed = self.decompress_tensor(
+                    compressed_state_dict[key],
+                    tensor_shapes[key],
+                    tensor_dtypes[key],
+                    target_device,
+                )
+                if use_bfloat16:
+                    tensors[key] = decompressed.bfloat16()
+                else:
+                    tensors[key] = decompressed
+            return tensors
\ No newline at end of file
diff --git a/triteia/python/utils/io.py b/triteia/python/utils/io.py
new file mode 100644
index 0000000..dc1cb73
--- /dev/null
+++ b/triteia/python/utils/io.py
@@ -0,0 +1,19 @@
+import safetensors as st
+from safetensors.torch import save_file
+
+def save_tensors(tensors, path):
+    for key in tensors.keys():
+        tensors[key] = tensors[key].contiguous()
+    save_file(tensors, path)
+
+def read_tensors(path, prefix=None, device='cpu'):
+    tensors = {}
+    with st.safe_open(path, framework="pt", device=device) as f:
+        for key in f.keys():
+            if prefix is None:
+                tensors[key] = f.get_tensor(key)
+            else:
+                if key.startswith(prefix):
+                    module_name = key.removeprefix(prefix + ".")
+                    tensors[module_name] = f.get_tensor(key)
+    return tensors
\ No newline at end of file
diff --git a/triteia/python/utils/quant_utils.py b/triteia/python/utils/quant_utils.py
new file mode 100644
index 0000000..9456bf9
--- /dev/null
+++ b/triteia/python/utils/quant_utils.py
@@ -0,0 +1,74 @@
+# adapted from https://github.com/IST-DASLab/marlin/blob/2e87035acf1b117aaf2c840c32b6a2b0a6c6ca4a/conversion/convert.py
+import torch
+
+@torch.no_grad()
+def unpack_4bit_to_32bit_signed(qweight, qzeros):
+    # Unpack 4-bit values and interpret them as signed integers
+    unpacked_weights = torch.zeros(
+        (qweight.shape[0]*8, qweight.shape[1]),
+        dtype=torch.int8,
+        device=qweight.device,
+        requires_grad=False
+    )
+    unpacked_zeros = torch.zeros(
+        (qzeros.shape[0], qzeros.shape[1]*8),
+        dtype=torch.int8,
+        device=qzeros.device,
+        requires_grad=False
+    )
+
+
+    for row in range(unpacked_weights.shape[0]):
+        i = row % 8
+        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
+
+    for col in range(unpacked_zeros.shape[1]):
+        i = col % 8
+        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+    if not torch.all(unpacked_zeros == 7):
+        raise ValueError(
+            "Marlin kernel is compatible only with checkpoints using symmetric quantization. "
+            "Found non-symmetric quantization for the weight"
+        )
+    return unpacked_weights, unpacked_zeros + 1
+
+@torch.no_grad()
+def dequantize_weight(qweight, qzeros, scales):
+    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
+    group_size = unpacked_qweight.shape[0] // scales.shape[0]
+    scales = scales.repeat_interleave(group_size, dim=0)
+    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
+    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
+    return unpacked_qweight.T
+
+@torch.no_grad()
+def gptq_unpack(bits, qweight, qzeros, scales, group_size=-1):
+    if group_size == -1:
+        group_size = qweight.shape[0] * 32 // bits
+    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
+    wf = wf.to(qweight.device)
+    zeros = torch.bitwise_right_shift(
+        torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0),
+    ).to(torch.int16 if bits == 8 else torch.int8)
+
+    zeros = zeros + 1
+    zeros = torch.bitwise_and(
+        zeros, (2**bits) - 1
+    )  # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
+
+    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+
+    scales = scales
+    scales = scales.reshape(-1, 1, scales.shape[-1])
+
+    weight = torch.bitwise_right_shift(
+        torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
+        wf.unsqueeze(-1),
+    ).to(torch.int16 if bits == 8 else torch.int8)
+    weight = torch.bitwise_and(weight, (2**bits) - 1)
+    weight = weight.reshape(-1, group_size, weight.shape[2])
+    weight = scales * (weight - zeros)
+    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+    return weight
\ No newline at end of file
diff --git a/triteia/tools/converters/convert_deltazip.py b/triteia/tools/converters/convert_deltazip.py
new file mode 100644
index 0000000..5087eb9
--- /dev/null
+++ b/triteia/tools/converters/convert_deltazip.py
@@ -0,0 +1,192 @@
+import json
+import cupy as cp
+from tqdm import tqdm
+import safetensors as st
+import torch, argparse
+from triteia.python.utils.io import save_tensors
+from triteia.python.utils.quant_utils import dequantize_weight
+from triteia.python.utils.compressor import LosslessCompressor
+from triteia.python.configs.models.llama import row_chunking_modules, uncompressed_row_chunking_modules, pack_modules
+from triteia.python.nn.linear import sparse_low_precision_linear
+
+@torch.no_grad()
+def torch_weight_to_sparse_marlin(weight, scale, tp_size=1, chunk_by="column"):
+    """
+    Args:
+        weight: torch.Tensor of shape (in_features, out_features)
+        scale: torch.Tensor of shape (1, out_features)
+        tp_size: tensor parallelism size
+        chunk_by: "column" or "row"
+    """
+    assert chunk_by in ["column", "row"], "chunk_by must be either 'column' or 'row'"
+    assert weight.dim() == 2, "weight must be a 2D tensor"
+    assert weight.size(0) % tp_size == 0, "in_features must be divisible by tp_size"
+    assert weight.size(1) == scale.size(1), "out_features of weight and scale must match"
+
+    if not weight.is_contiguous():
+        weight = weight.contiguous()
+    if not scale.is_contiguous():
+        scale = scale.contiguous()
+
+    qweights, scales, metas = [], [], []
+    for i in range(tp_size):
+        if chunk_by == "column":
+            tp_weight = weight[
+                :,
+                i * weight.size(1) // tp_size: (i + 1) * weight.size(1) // tp_size
+            ]
+            tp_scales = scale[
+                :,
+                i * scale.size(1) // tp_size: (i + 1) * scale.size(1) // tp_size
+            ]
+        elif chunk_by == "row":
+            tp_weight = weight[
+                i * weight.size(0) // tp_size: (i + 1) * weight.size(0) // tp_size,
+                :
+            ]
+            tp_scales = scale
+        layer = sparse_low_precision_linear(
+            infeatures=tp_weight.size(0),
+            outfeatures=tp_weight.size(1),
+            groupsize=-1
+        )
+        k, m = tp_weight.size(0), tp_weight.size(1)
+        k_sp = k // 2
+        layer.groupsize = k
+        layer.B = torch.empty((k_sp // 16, m * 16 // 8), dtype=torch.int)
+        layer.meta = torch.empty((m, k // 16), dtype=torch.int16)
+        layer.s = torch.empty((k_sp // (k // 2), m), dtype=torch.half)
+        layer.pack(tp_weight, scales=tp_scales, trans=True)
+        qweights.append(layer.B)
+        scales.append(layer.s)
+        metas.append(layer.meta)
+    return qweights, scales, metas
+
+@torch.no_grad()
+def convert_model(args, verbose=True):
+    DEV = "cuda:0"
+
+    new_tensors = {}
+    tensors = {}
+    packed_tensors = {}
+    dequantized_tensors = {}
+    remaining_keys = []
+
+    with st.safe_open(args.ckpt, framework="torch", device="cuda:0") as f:
+        keys = f.keys()
+        remaining_keys = list(f.keys())
+        metadata = f.metadata()
+        for key in keys:
+            tensors[key] = f.get_tensor(key)
+        if args.lossless:
+            tensors_dtypes = json.loads(metadata["dtype"])
+            tensors_shapes = json.loads(metadata["shape"])
+
+    if args.lossless:
print(f"Decompressing from lossless format...") + with cp.cuda.Device(0): + for key in tensors.keys(): + tensors[key] = cp.array(tensors[key], copy=False) + lc = LosslessCompressor() + tensors = lc.decompress_state_dict( + tensors, + tensors_shapes, + tensors_dtypes, + use_bfloat16=False, + target_device="cuda:0", + ) + # infeatures, outfeatures + quantized_modules = [ + x.removesuffix(".qweight") for x in tensors.keys() if "qweight" in x + ] + pbar = tqdm(quantized_modules, position=0, leave=True) + print("Dequantizing weights...") + for module in pbar: + dequantized_weight = dequantize_weight( + tensors[module + ".qweight"], + tensors[module + ".qzeros"], + tensors[module + ".scales"], + ).to(torch.float16).t().cpu() + scales = tensors[module + ".scales"] + dequantized_tensors[module] = (dequantized_weight, scales) + remaining_keys.remove(module + ".qweight") + remaining_keys.remove(module + ".qzeros") + remaining_keys.remove(module + ".scales") + remaining_keys.remove(module + ".g_idx") + + # now start to pack weights together + pack_plan = {} + for module in quantized_modules: + if any([key in module for key in pack_modules.keys()]): + source_layer = module.rsplit(".", 2)[0] + source_module = module.replace(source_layer+".", "") + target_module = pack_modules[source_module] + target_idx = int(target_module.split(":")[1]) + target_module = source_layer + "." + target_module.split(":")[0] + if target_module not in pack_plan: + pack_plan[target_module] = [] + pack_plan[target_module].append((module, target_idx)) + + elif any([key in module for key in row_chunking_modules]): + qweights, scales, metas = torch_weight_to_sparse_marlin( + dequantized_tensors[module][0].to(DEV), + dequantized_tensors[module][1].to(DEV), + tp_size=args.tp_size, + chunk_by="row", + ) + for idx, (qweight, scales, meta) in enumerate(zip(qweights, scales, metas)): + new_tensors[module + f".{idx}.qweight"] = qweight + new_tensors[module + f".{idx}.scales"] = scales + new_tensors[module + f".{idx}.meta"] = meta + for key in pack_plan.keys(): + key_weights = [] + key_scales = [] + plan = sorted(pack_plan[key], key=lambda x: x[1]) + print(f"Plan for {key}: {plan}") + for module, idx in plan: + weight, scales = dequantized_tensors[module] + assert weight.shape[1] == scales.shape[1] + key_weights.append(weight) + key_scales.append(scales) + key_weights = torch.cat(key_weights, dim=1) + key_scales = torch.cat(key_scales, dim=1) + packed_tensors[key] = (key_weights, key_scales) + torch.cuda.synchronize() + del dequantized_tensors[module] + torch.cuda.empty_cache() + + qweights, scales, metas = torch_weight_to_sparse_marlin( + packed_tensors[key][0].to(DEV), + packed_tensors[key][1].to(DEV), + tp_size=args.tp_size, + chunk_by="column", + ) + for idx, (qweight, scales, meta) in enumerate(zip(qweights, scales, metas)): + new_tensors[key + f".{idx}.qweight"] = qweight + new_tensors[key + f".{idx}.scales"] = scales + new_tensors[key + f".{idx}.meta"] = meta + + # # now processing remaining keys + for module in remaining_keys: + if any([key in module for key in uncompressed_row_chunking_modules]): + weight = tensors[module] + module_name = module.removesuffix(".weight") + num_rows = weight.shape[0] + for i in range(args.tp_size): + tp_weight = weight[i * num_rows // args.tp_size: (i + 1) * num_rows // args.tp_size, :] + new_tensors[module_name + f".{i}.weight"] = tp_weight + + return new_tensors + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str) + 
parser.add_argument("--tp-size", type=int) + parser.add_argument("--save-path", type=str) + parser.add_argument("--lossless", action="store_true") + parser.add_argument("--pack", action="store_true") + args = parser.parse_args() + + print("Converting model...") + new_tensors = convert_model(args, verbose=True) + save_tensors(new_tensors, args.save_path) \ No newline at end of file diff --git a/triteia/tools/verify_weights.py b/triteia/tools/verify_weights.py new file mode 100644 index 0000000..d1e6c5c --- /dev/null +++ b/triteia/tools/verify_weights.py @@ -0,0 +1,78 @@ +import torch +from triteia.python.utils.io import read_tensors +from triteia.python.ops import matmul_4bit_2_4 +from triteia.python.ops.utils.generator import generate_model_distribution + +def check_tp_group_equal(weights, reference_weights): + modules = set() + tp_groups = set() + + for key in weights.keys(): + # separate by . + # last element - component, second last - tp id, others - module name + tp_group = key.split(".")[-2] + tp_groups.add(tp_group) + module_name = ".".join(key.split(".")[:-2]) + modules.add(module_name) + + for module in modules: + tp_groups_in_modules = max([int(key.split(".")[-2]) for key in weights.keys() if module in key]) + 1 + components_in_modules = [key.split(".")[-1] for key in weights.keys() if module in key] + for component in components_in_modules: + components_across_tp = [value for key, value in weights.items() if module in key and component in key] + # there should be at most tp_groups_in_modules tensors for each component + assert len(components_across_tp) == tp_groups_in_modules, f"Module {module} has {len(components_across_tp)} components for {component}" + # check if there are same tensors for each component + for i in range(1, len(components_across_tp)): + if torch.equal(components_across_tp[i-1], components_across_tp[i]): + print(f"Module {module} has same tensors for {component} in tp group {i-1} and {i}") + +def check_output(weights, reference_weights, module_name): + target_weight = {key: value for key, value in weights.items() if module_name in key} + reference_weight = {key: value for key, value in reference_weights.items() if module_name in key} + tp_groups = set() + for key in weights.keys(): + # separate by . 
+        # last element - component, second last - tp id, others - module name
+        tp_group = key.split(".")[-2]
+        tp_groups.add(tp_group)
+    reference_qweight = reference_weights[f"{module_name}.0.qweight"]
+    reference_meta = reference_weights[f"{module_name}.0.meta"]
+    reference_scale = reference_weights[f"{module_name}.0.scales"]
+
+    nr = 10
+    x = torch.randn((nr, 32 * reference_qweight.size(0)), dtype=torch.float16, device='cuda')
+    reference_output = matmul_4bit_2_4(reference_qweight, x, reference_meta, reference_scale)
+    tp_outputs = []
+    tp_groups = sorted(list(tp_groups))
+    for tp in tp_groups:
+        qweight = target_weight[f"{module_name}.{tp}.qweight"]
+        meta = target_weight[f"{module_name}.{tp}.meta"]
+        scale = target_weight[f"{module_name}.{tp}.scales"]
+        output = matmul_4bit_2_4(qweight, x, meta, scale)
+        tp_outputs.append(output)
+    tp_output = torch.cat(tp_outputs, dim=1)
+
+    print(f"reference_output: {reference_output.shape}, tp_output: {tp_output.shape}")
+    print(f"first half reference_out: \n{reference_output[:, :reference_output.size(1)//2]}\nfirst half tp_out: \n{tp_output[:, :tp_output.size(1)//2]}")
+
+    print(f"second half reference_out: \n{reference_output[:, reference_output.size(1)//2:]}\nsecond half tp_out: \n{tp_output[:, tp_output.size(1)//2:]}")
+
+    print(f"reference_output: \n{reference_output}\ntp_output: \n{tp_output}")
+
+    print(f"max diff: {torch.max(torch.abs(reference_output - tp_output))}")
+
+def verify(args):
+    print(args)
+    weights = read_tensors(args.input, device='cuda')
+    reference_weights = read_tensors(args.reference_input, device='cuda')
+    check_output(weights, reference_weights, "model.layers.9.self_attn.qkv_proj")
+    # check_tp_group_equal(weights, reference_weights)
+
+if __name__=="__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", type=str, help="Path to the input file")
+    parser.add_argument("--reference-input", default="", type=str, help="Path to the reference input file")
+    args = parser.parse_args()
+    verify(args)
\ No newline at end of file