Update convert_grok.py to use logging module
mofosyne authored May 9, 2024
1 parent c9c8952 commit b16d543
Showing 1 changed file with 13 additions and 7 deletions.
convert_grok.py: 20 changes (13 additions, 7 deletions)
@@ -34,6 +34,8 @@

import gguf

+logger = logging.getLogger("convert_grok")

GGML_QK8_0 = 32
GGML_QK4_0 = 32
GGML_QK4_1 = 32
@@ -214,7 +216,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
tensor_ggml_type,
)
weights[name] = weight, scales
print("Loaded", len(weight_names), "files")
logger.info("Loaded", len(weight_names), "files")

f.write_header_to_file()
f.write_kv_data_to_file()
@@ -230,7 +232,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()

-print(
+logger.debug(
f"dumping {name}:",
f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
)
@@ -239,12 +241,12 @@ def dump_state_dict(f, ggml_type, input_dir, config):
tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))

try:
-print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
+print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) # noqa: NP100
except NameError:
pass

if len(tensor_info) != len(weight_names):
print("Warning: not all tensors are converted")
logger.warning("Not all tensors are converted")


def from_numpy(array):
@@ -377,7 +379,7 @@ def ffn_size(emb_size, widening_factor):
config.num_experts = len(config.experts)

assert config.num_experts >= 2, "need at least 2 experts"
print("experts to export:", config.experts)
logger.info("experts to export:", config.experts)

f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)

@@ -409,12 +411,12 @@ def ffn_size(emb_size, widening_factor):

delta = time.time() - start

print(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec")
logger.info(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec")


def load_vocab(path):
def load_spm(p):
print(f"Loading vocab file {p}")
logger.info(f"Loading vocab file {p}")
return SentencePieceVocab(p)

# Be extra-friendly and accept either a file or a directory. Also, if it's
@@ -445,8 +447,12 @@ def main():
)
parser.add_argument("--vocab_dir", type=str, default="")
parser.add_argument("--experts", type=str, default="")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")

args = parser.parse_args()

+logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

vocab = load_vocab(
pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
)
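For reference, the logging pattern this commit adopts reduces to the sketch below. The logger name "convert_grok", the --verbose flag, and the basicConfig() call mirror the diff above; the rest of the script is illustrative only.

import argparse
import logging

logger = logging.getLogger("convert_grok")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()

    # --verbose switches the root logger from INFO to DEBUG, so logger.debug()
    # messages (such as the per-tensor dump lines above) appear only on request.
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    logger.info("normal progress message")
    logger.debug("only visible with --verbose")


if __name__ == "__main__":
    main()

Because basicConfig() configures the root logger, the module-level logger created with logging.getLogger("convert_grok") inherits the chosen level, so the conversion helpers can call logger.info(), logger.debug(), and logger.warning() without any further setup.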
