commit 248d5daaff (parent 8f5921692e)
Author: hiyouga
Date:   2024-10-29 09:07:46 +00:00

    use pre-commit

    Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb

66 changed files with 1028 additions and 1044 deletions


@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
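The dropped line is the `# coding=utf-8` cookie. Python 3 already treats source files as UTF-8 (PEP 3120), so the declaration is redundant, and pyupgrade-style pre-commit hooks (e.g. Ruff's rule UP009, assuming Ruff is among the configured hooks) strip it automatically. A minimal illustration:

    # No coding declaration needed: Python 3 parses source as UTF-8 by default.
    greeting = "héllo, 世界"
    print(greeting)  # behaves identically with or without the old cookie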
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
             torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+        print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")


 def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         llama2_config_dict: Dict[str, Any] = json.load(f)

     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
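Both edits in this hunk are mechanical style upgrades with identical behavior: `str.format` calls become f-strings, and the explicit `"r"` mode is dropped from `open()` because reading is the default. A minimal sketch of the two idioms, assuming pyupgrade-style hooks (e.g. Ruff rules UP032 and UP015) drive the rewrite; the file names below are placeholders:

    import json
    import os

    output_dir = "."  # placeholder directory for illustration
    WEIGHTS_NAME = "pytorch_model.bin"  # mirrors the constant used in the script

    # The f-string is equivalent to the str.format call it replaces.
    old_style = "Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))
    new_style = f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}"
    assert old_style == new_style

    # open() defaults to mode "r", so spelling it out is redundant.
    with open("config.json", "w", encoding="utf-8") as f:
        json.dump({"architectures": ["LlamaForCausalLM"]}, f)
    with open("config.json", encoding="utf-8") as f:  # no "r" needed
        config = json.load(f)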
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)

-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_baichuan2(
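The final hunk applies the same f-string rewrite inside `save_config`. A repository-wide sweep like this one (66 files, roughly a thousand lines changed in each direction) is typically produced by installing the hooks with `pre-commit install` and then running `pre-commit run --all-files` once, which applies every configured hook across the entire tree.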