fix version checking
Former-commit-id: 5780da8d640609cca388f55983d0251e5547209a
33
scripts/cal_flops.py
Normal file
@@ -0,0 +1,33 @@
# coding=utf-8
# Calculates the FLOPs of pre-trained models.
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/

from typing import Optional

import fire
import torch
from deepspeed.accelerator import get_accelerator  # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile  # type: ignore

from llmtuner import ChatModel


def calculate_flops(
    model_name_or_path: str,
    batch_size: Optional[int] = 1,
    seq_length: Optional[int] = 256,
    flash_attn: Optional[bool] = False,
):
    with get_accelerator().device(0):
        chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
        input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
        flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
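        # Note (added): rough sanity check only, not part of the profiler output. A decoder-only
        # transformer needs roughly 2 * params FLOPs per token for a forward pass, so the reported
        # value should be in the ballpark of 2 * params * batch_size * seq_length.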
        print("FLOPs:", flops)
        print("MACs:", macs)
        print("Params:", params)


if __name__ == "__main__":
    fire.Fire(calculate_flops)
77
scripts/cal_lr.py
Normal file
@@ -0,0 +1,77 @@
# coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py

import math
from typing import Optional

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer


BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000  # from llama paper
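# Note (added): the rule used below assumes the learning rate scales with the square root of the
# effective batch size in tokens, anchored at LLaMA's 3e-4 / 4M-token setting. Worked example with
# illustrative numbers only: for batch_size=16, cutoff_len=1024 and a 60% valid-token ratio,
# batch_valid_len = 16 * 1024 * 0.6 ≈ 9830 tokens, so lr ≈ 3e-4 * sqrt(9830 / 4e6) ≈ 1.5e-5.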


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Optional[str] = "sft",
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    cutoff_len: Optional[int] = 1024,  # i.e. maximum input length during training
    is_mistral: Optional[bool] = False,  # mistral model uses a smaller learning rate
):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    else:
        raise NotImplementedError

    dataloader = DataLoader(
        dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
    )
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
        total_tokens += torch.numel(batch["labels"])

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    lr = lr / 6.0 if is_mistral else lr
    print(
        "Optimal learning rate is {:.2e} for valid ratio {:.2f}% and effective batch size {:.2f}".format(
            lr, valid_ratio * 100, batch_valid_len
        )
    )


if __name__ == "__main__":
    fire.Fire(calculate_lr)
52
scripts/length_cdf.py
Normal file
@@ -0,0 +1,52 @@
# coding=utf-8
# Calculates the distribution of the input lengths in the dataset.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default

from collections import defaultdict
from typing import Optional

import fire
from tqdm import tqdm

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
    total_num = len(trainset)
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1
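    # Note (added): each tokenized sample is bucketed into [k * interval, (k + 1) * interval);
    # the loop below accumulates the bucket counts so that each printed line reports the
    # cumulative share of samples shorter than the next multiple of `interval`.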

    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))


if __name__ == "__main__":
    fire.Fire(length_cdf)
115
scripts/llama_pro.py
Normal file
@@ -0,0 +1,115 @@
# coding=utf-8
# Performs block expansion for LLaMA, Mistral or Qwen1.5 models.
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py

import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: Optional[str] = "2GB",
    save_safetensors: Optional[bool] = False,
):
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)

    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
    if save_safetensors:
        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    if num_layers % num_expand != 0:
        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))

    split = num_layers // num_expand
    layer_cnt = 0
    output_state_dict = OrderedDict()
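    # Note (added): every `split` original decoder layers are followed by one newly inserted copy
    # of the last layer in that group. Zeroing its `o_proj` and `down_proj` makes the new block an
    # identity mapping at initialization (only the residual stream passes through), which is the
    # LLaMA-Pro block-expansion recipe referenced above.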
    for i in range(num_layers):
        for key, value in state_dict.items():
            if ".{:d}.".format(i) in key:
                output_state_dict[change_name(key, i, layer_cnt)] = value

        print("Add layer {} copied from layer {}".format(layer_cnt, i))
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in state_dict.items():
                if ".{:d}.".format(i) in key:
                    if "down_proj" in key or "o_proj" in key:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                    else:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
            layer_cnt += 1

    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))

    print("Fine-tune this model with:")
    print("    --model_name_or_path {} \\".format(output_dir))
    print("    --finetuning_type freeze \\")
    print("    --name_module_trainable all \\")
    print("    --num_layer_trainable {} \\".format(num_expand))
    print("    --use_llama_pro")


if __name__ == "__main__":
    fire.Fire(block_expansion)
92
scripts/llamafy_baichuan2.py
Normal file
@@ -0,0 +1,92 @@
# coding=utf-8
# Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            baichuan2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
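    # Note (added): Baichuan2 packs the query/key/value projections into a single `W_pack` matrix,
    # so the conversion slices it into three equal parts for q/k/v_proj. The lm_head weight is
    # L2-normalized because Baichuan2's NormHead normalizes the output embedding; the remaining
    # tensors already line up with the LLaMA layout and are copied as-is.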
    for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
        if "W_pack" in key:
            proj_size = value.size(0) // 3
            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
        elif "lm_head" in key:
            llama2_state_dict[key] = torch.nn.functional.normalize(value)
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("tokenizer_class", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_baichuan2(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir already exists: {}".format(e))

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_baichuan2)
114
scripts/llamafy_internlm2.py
Normal file
@@ -0,0 +1,114 @@
# coding=utf-8
# Converts the InternLM2 model into the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        internlm2_config_dict: Dict[str, Any] = json.load(f)

    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            internlm2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
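    # Note (added): the `wqkv` split below assumes the fused projection stores all query rows
    # first, followed by the key rows and then the value rows. With grouped-query attention,
    # q_size = head_dim * num_q_heads and kv_size = head_dim * num_kv_heads. If the checkpoint
    # interleaves q/k/v per key-value group instead, the rows would need reordering first, which
    # may be related to the warning above about the converted model not inferring correctly.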
    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
        if "output" in key:
            llama2_state_dict[key.replace("output", "lm_head")] = value
        elif "tok_embeddings" in key:
            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
        elif "wqkv" in key:
            num_q_heads = internlm2_config_dict["num_attention_heads"]
            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
                q_size : q_size + kv_size, ...
            ]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
        elif "wo" in key:
            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
        elif "attention_norm" in key:
            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
        elif "ffn_norm" in key:
            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
        elif "w1" in key:
            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
        elif "w2" in key:
            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
        elif "w3" in key:
            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("bias", None)
    llama2_config_dict.pop("rope_scaling", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_internlm2(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir already exists: {}".format(e))

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_internlm2)
144
scripts/llamafy_qwen.py
Normal file
@@ -0,0 +1,144 @@
# coding=utf-8
# Converts the Qwen models into the same format as LLaMA2.
# Usage: python llamafy_qwen.py --input_dir input --output_dir output
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)
from transformers.utils import check_min_version


try:
    check_min_version("4.34.0")
except Exception:
    raise ValueError("Please upgrade `transformers` to 4.34.0")


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
            with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                for key in f.keys():
                    qwen_state_dict[key] = f.get_tensor(key)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    torch_dtype = None
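    # Note (added): Qwen fuses q/k/v into a single `attn.c_attn` tensor, which is sliced into
    # three equal parts below. Qwen's attention output projection has no bias, while the target
    # LLaMA config enables `attention_bias`, so a zero `o_proj.bias` is synthesized to keep the
    # two layouts compatible. Qwen's `mlp.w1`/`mlp.w2` map to LLaMA's up/gate projections.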
    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
        if torch_dtype is None:
            torch_dtype = value.dtype
        if "wte" in key:
            llama2_state_dict["model.embed_tokens.weight"] = value
        elif "ln_f" in key:
            llama2_state_dict["model.norm.weight"] = value
        else:
            key = key.replace("transformer.h", "model.layers")
            if "attn.c_attn" in key:
                proj_size = value.size(0) // 3
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
                    proj_size : 2 * proj_size, ...
                ]
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
            elif "attn.c_proj" in key:
                llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
                llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
                    value[:, 0]
                ).squeeze()
            elif "ln_1" in key:
                llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
            elif "ln_2" in key:
                llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
            elif "mlp.w1" in key:
                llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
            elif "mlp.w2" in key:
                llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
            elif "mlp.c_proj" in key:
                llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
            elif "lm_head" in key:
                llama2_state_dict[key] = value
            else:
                raise KeyError("Unable to process key {}".format(key))

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))

    return str(torch_dtype).replace("torch.", "")


def save_config(input_dir: str, output_dir: str, torch_dtype: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        qwen_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict: Dict[str, Any] = OrderedDict()
    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict["hidden_act"] = "silu"
    llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
    llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
    llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
    llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
    llama2_config_dict["model_type"] = "llama"
    llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
    llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
    llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
    llama2_config_dict["pretraining_tp"] = 1
    llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
    llama2_config_dict["rope_scaling"] = None
    llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
    llama2_config_dict["torch_dtype"] = torch_dtype
    llama2_config_dict["transformers_version"] = "4.34.0"
    llama2_config_dict["use_cache"] = True
    llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
    llama2_config_dict["attention_bias"] = True
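    # Note (added): in the mapping above, Qwen reports the SwiGLU width as the sum of both halves,
    # so LLaMA's `intermediate_size` is half of Qwen's value; `num_key_value_heads` is derived from
    # hidden_size / kv_channels (Qwen uses full multi-head attention, so it equals the number of
    # attention heads); and `attention_bias` is enabled because Qwen's q/k/v projections carry biases.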

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_qwen(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir already exists: {}".format(e))

    torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir, torch_dtype)


if __name__ == "__main__":
    fire.Fire(llamafy_qwen)
82
scripts/loftq_init.py
Normal file
@@ -0,0 +1,82 @@
# coding=utf-8
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py

import os
from typing import TYPE_CHECKING, Optional

import fire
import torch
import torch.nn as nn
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


class Shell(nn.Module):
    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)


def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
    for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
        parent_name = ".".join(name.split(".")[:-1])
        child_name = name.split(".")[-1]
        parent_module = model.get_submodule(parent_name)
        child_module = getattr(parent_module, child_name)
        base_layer = getattr(child_module, "base_layer")
        weight = getattr(base_layer, "weight", None)
        bias = getattr(base_layer, "bias", None)
        setattr(parent_module, child_name, Shell(weight, bias))

    print("Model unwrapped.")


def quantize_loftq(
    model_name_or_path: str,
    save_dir: str,
    loftq_bits: Optional[int] = 4,
    loftq_iter: Optional[int] = 1,
    lora_alpha: Optional[int] = None,
    lora_rank: Optional[int] = 16,
    lora_target: Optional[str] = "q_proj,v_proj",
    save_safetensors: Optional[bool] = False,
):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
    loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=lora_rank,
        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
        lora_dropout=0.1,
        target_modules=[name.strip() for name in lora_target.split(",")],
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Init LoftQ model
    lora_model = get_peft_model(model, lora_config)
    base_model: "PreTrainedModel" = lora_model.get_base_model()
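    # Note (added): `init_lora_weights="loftq"` makes PEFT initialize the adapters so that the
    # quantized base weights plus the low-rank pair approximate the original full-precision
    # weights. The adapters are saved separately below, while `unwrap_model` strips the PEFT
    # wrappers (replacing each LoRA layer with a bare `Shell` holding only weight/bias) so the
    # adjusted backbone can be saved as an ordinary checkpoint and fine-tuned later.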

    # Save LoftQ model
    setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
    setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
    lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)

    # Save base model
    unwrap_model(base_model)
    base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
    tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    fire.Fire(quantize_loftq)
57
scripts/test_toolcall.py
Normal file
@@ -0,0 +1,57 @@
import json
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour
    return total_score / total_hour


tool_map = {"calculate_gpa": calculate_gpa}
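# Note (added): the block below exercises one full tool-call round trip against a local
# OpenAI-compatible server (assumed here to be LLaMA-Factory's API demo on port 8000): the first
# completion is expected to return a `calculate_gpa` tool call, the call is executed locally via
# `tool_map`, the result is sent back as a "tool" message, and the second completion should answer
# with the GPA. For grades A, A, B, C and hours 3, 4, 3, 2 that is (12 + 16 + 9 + 4) / 12 ≈ 3.42.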


if __name__ == "__main__":
    client = OpenAI(
        api_key="0",
        base_url="http://localhost:8000/v1",
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    tool_call = result.choices[0].message.tool_calls[0].function
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    messages.append(
        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
    )
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.