use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
|
||||
def save_config(input_dir: str, output_dir: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
|
||||
Reference in New Issue
Block a user