use pre-commit

Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
hiyouga
2024-10-29 09:07:46 +00:00
parent 8f5921692e
commit 248d5daaff
66 changed files with 1028 additions and 1044 deletions
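Editor's note: the hunks below are the mechanical rewrites a pre-commit hook such as Ruff's pyupgrade (UP) rules produces — dropping the redundant `# coding=utf-8` declaration (UTF-8 is the Python 3 default), converting `str.format` calls to f-strings, and removing the default `"r"` mode from `open`. The repository's actual hook configuration is not shown in this diff. As a sanity check, the f-string rewrites are behavior-preserving, format specs included:

# Each pair below produces identical strings, so the rewrites are safe.
stage = "rm"
assert "Stage: {}.".format(stage) == f"Stage: {stage}."

mfu_value = 0.51234
assert "MFU: {:.2f}%".format(mfu_value * 100) == f"MFU: {mfu_value * 100:.2f}%"

old_index = 7
assert ".{:d}.".format(old_index) == f".{old_index:d}."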

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
 # This code is inspired by the Microsoft's DeepSpeed library.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
 # This code is inspired by the imoneoi's OpenChat library.
@@ -74,7 +73,7 @@ def calculate_lr(
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage is not supported: {stage}.")

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
     elif "4090" in device_name:
         return 98 * 1e12 * world_size
     else:
-        raise NotImplementedError("Device not supported: {}.".format(device_name))
+        raise NotImplementedError(f"Device not supported: {device_name}.")


 def calculate_mfu(
@@ -140,10 +139,10 @@ def calculate_mfu(
         "bf16": True,
     }
     if deepspeed_stage in [2, 3]:
-        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
         result = json.load(f)

     if dist.is_initialized():
@@ -157,7 +156,7 @@ def calculate_mfu(
             * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
             / compute_device_flops(world_size)
         )
-        print("MFU: {:.2f}%".format(mfu_value * 100))
+        print(f"MFU: {mfu_value * 100:.2f}%")


 if __name__ == "__main__":
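Editor's note: MFU (model FLOPs utilization) is the ratio of achieved model FLOPs per second to the hardware's peak throughput, which is what the expression above computes. A minimal sketch under that definition, with hypothetical numbers — the 98 TFLOPS figure for the RTX 4090 comes from compute_device_flops above, and the step rate stands in for the value read from all_results.json:

def mfu(steps_per_second: float, model_flops_per_step: float, peak_flops: float) -> float:
    # Achieved FLOPs per second divided by peak FLOPs per second.
    return steps_per_second * model_flops_per_step / peak_flops

# Hypothetical: one RTX 4090 (98 TFLOPS per the diff above),
# 2e12 model FLOPs per step at 20 steps/s -> 40.82% MFU.
print(f"MFU: {mfu(20.0, 2e12, 98e12) * 100:.2f}%")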

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def calculate_ppl(
             tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
         )
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage is not supported: {stage}.")

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
@@ -125,8 +124,8 @@ def calculate_ppl(
     with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)

-    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
-    print("Perplexities have been saved at {}.".format(save_name))
+    print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
+    print(f"Perplexities have been saved at {save_name}.")


 if __name__ == "__main__":
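Editor's note: this script derives per-sample perplexities from the unreduced cross-entropy (criterion above) and reports their mean. A minimal sketch of the per-sample computation, assuming the usual causal-LM shift and IGNORE_INDEX = -100 label masking:

import torch

def sequence_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> float:
    # logits: (seq_len, vocab); labels: (seq_len,) with -100 on masked positions.
    # Shift so tokens < n predict token n, as in causal LM training.
    criterion = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
    losses = criterion(logits[:-1], labels[1:])
    mask = labels[1:] != -100
    # Perplexity is exp of the mean negative log-likelihood over valid tokens.
    return torch.exp(losses[mask].mean()).item()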

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +60,7 @@ def length_cdf(
     for length, count in length_tuples:
         count_accu += count
         prob_accu += count / total_num * 100
-        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


 if __name__ == "__main__":
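Editor's note: the loop above walks length buckets in ascending order and accumulates a cumulative distribution. A self-contained sketch of the same idea, assuming lengths are bucketed to multiples of a fixed interval:

from collections import defaultdict

def length_cdf(lengths: list, interval: int = 1000) -> None:
    # Bucket each length down to a multiple of `interval`, then walk the
    # buckets in order, accumulating counts into a CDF.
    counts = defaultdict(int)
    for length in lengths:
        counts[length // interval * interval] += 1
    total_num, count_accu, prob_accu = len(lengths), 0, 0.0
    for length, count in sorted(counts.items()):
        count_accu += count
        prob_accu += count / total_num * 100
        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")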

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
 # This code is inspired by the Tencent's LLaMA-Pro library.
@@ -40,7 +39,7 @@ if TYPE_CHECKING:

 def change_name(name: str, old_index: int, new_index: int) -> str:
-    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
+    return name.replace(f".{old_index:d}.", f".{new_index:d}.")


 def block_expansion(
@@ -76,27 +75,27 @@ def block_expansion(
     state_dict = model.state_dict()

     if num_layers % num_expand != 0:
-        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
+        raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

     split = num_layers // num_expand
     layer_cnt = 0
     output_state_dict = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
-            if ".{:d}.".format(i) in key:
+            if f".{i:d}." in key:
                 output_state_dict[change_name(key, i, layer_cnt)] = value

-        print("Add layer {} copied from layer {}".format(layer_cnt, i))
+        print(f"Add layer {layer_cnt} copied from layer {i}")
         layer_cnt += 1
         if (i + 1) % split == 0:
             for key, value in state_dict.items():
-                if ".{:d}.".format(i) in key:
+                if f".{i:d}." in key:
                     if "down_proj" in key or "o_proj" in key:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                     else:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

-            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
+            print(f"Add layer {layer_cnt} expanded from layer {i}")
             layer_cnt += 1

     for key, value in state_dict.items():
@@ -113,17 +112,17 @@ def block_expansion(
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)

-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
+    print(f"model_name_or_path: {output_dir}")
     print("finetuning_type: freeze")
-    print("freeze_trainable_layers: {}".format(num_expand))
+    print(f"freeze_trainable_layers: {num_expand}")
     print("use_llama_pro: true")

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+        print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)

-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")


 def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         llama2_config_dict: Dict[str, Any] = json.load(f)

     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)

-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_baichuan2(
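Editor's note: both llamafy scripts share this save protocol — shard the converted state dict, write each shard, and emit a JSON index only when more than one shard exists. A sketch of reading such an index back, assuming the standard Hugging Face index layout with a "weight_map" from tensor name to shard file (shard_for is a hypothetical helper):

import json
import os

def shard_for(tensor_name: str, output_dir: str, index_name: str) -> str:
    # The index maps each tensor name to the shard file that stores it.
    with open(os.path.join(output_dir, index_name), encoding="utf-8") as f:
        index = json.load(f)
    return index["weight_map"][tensor_name]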

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors
         elif "lm_head" in key:
             llama2_state_dict[key] = value
         else:
-            raise KeyError("Unable to process key {}".format(key))
+            raise KeyError(f"Unable to process key {key}")

     weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
     shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)

-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")

     return str(torch_dtype).replace("torch.", "")


 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         qwen_config_dict: Dict[str, Any] = json.load(f)

     llama2_config_dict: Dict[str, Any] = OrderedDict()
@@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)

-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_qwen(

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -70,19 +69,19 @@ def quantize_loftq(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
     peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(loftq_dir))
+    print(f"Adapter weights saved in {loftq_dir}")

     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(loftq_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {loftq_dir}")
     print("finetuning_type: lora")
-    print("quantization_bit: {}".format(loftq_bits))
+    print(f"quantization_bit: {loftq_bits}")


 if __name__ == "__main__":
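Editor's note: the printed stanza above tells users how to consume the two artifacts — the unloaded base model and the LoftQ-initialized adapter saved with init_lora_weights=True so the decomposition is not reapplied on load. A hedged sketch of loading them with PEFT (directory names hypothetical, standing in for output_dir and loftq_dir):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical paths matching the printed hints above.
base = AutoModelForCausalLM.from_pretrained("output/base")
model = PeftModel.from_pretrained(base, "output/base/loftq_init")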

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -54,7 +53,7 @@ def quantize_pissa(
         lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
         lora_dropout=lora_dropout,
         target_modules=lora_target,
-        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+        init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
     )

     # Init PiSSA model
@@ -65,17 +64,17 @@ def quantize_pissa(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
     peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(pissa_dir))
+    print(f"Adapter weights saved in {pissa_dir}")

     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(pissa_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {pissa_dir}")
     print("finetuning_type: lora")
     print("pissa_init: false")
     print("pissa_convert: true")

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");