modity code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
0
src/llmtuner/extras/__init__.py
Normal file
0
src/llmtuner/extras/__init__.py
Normal file
72
src/llmtuner/extras/callbacks.py
Normal file
72
src/llmtuner/extras/callbacks.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.tracker = {}
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "current_steps" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"eval_loss": state.log_history[-1].get("eval_loss", None),
|
||||
"predict_loss": state.log_history[-1].get("predict_loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
||||
7
src/llmtuner/extras/constants.py
Normal file
7
src/llmtuner/extras/constants.py
Normal file
@@ -0,0 +1,7 @@
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
||||
18
src/llmtuner/extras/logging.py
Normal file
18
src/llmtuner/extras/logging.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
105
src/llmtuner/extras/misc.py
Normal file
105
src/llmtuner/extras/misc.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoid runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 0] = 1.0
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def print_trainable_params(model: torch.nn.Module) -> None:
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
# if using DS Zero 3 and the weights are initialized empty
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
trainable_params += num_params
|
||||
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param))
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: PreTrainedModel,
|
||||
finetuning_type: str,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> PreTrainedModel:
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
|
||||
return model
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
"""
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
50
src/llmtuner/extras/ploting.py
Normal file
50
src/llmtuner/extras/ploting.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Optional
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
smoothed.append(smoothed_val)
|
||||
last = smoothed_val
|
||||
return smoothed
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
steps, metrics = [], []
|
||||
for i in range(len(data["log_history"])):
|
||||
if key in data["log_history"][i]:
|
||||
steps.append(data["log_history"][i]["step"])
|
||||
metrics.append(data["log_history"][i][key])
|
||||
|
||||
if len(metrics) == 0:
|
||||
logger.warning(f"No metric {key} to plot.")
|
||||
continue
|
||||
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
||||
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
||||
49
src/llmtuner/extras/save_and_load.py
Normal file
49
src/llmtuner/extras/save_and_load.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
|
||||
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
|
||||
state_dict = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
|
||||
|
||||
return filtered_state_dict
|
||||
|
||||
|
||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_file):
|
||||
model_state_dict = torch.load(weights_file, map_location="cpu")
|
||||
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
|
||||
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
|
||||
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
|
||||
else:
|
||||
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
|
||||
if not os.path.exists(valuehead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
return True
|
||||
173
src/llmtuner/extras/template.py
Normal file
173
src/llmtuner/extras/template.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
name: str
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.name == "vanilla":
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="{query}",
|
||||
sep="",
|
||||
use_history=False
|
||||
)
|
||||
|
||||
elif self.name == "default":
|
||||
r"""
|
||||
Default template.
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
prompt="Human: {query}\nAssistant: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "alpaca":
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.",
|
||||
prompt="### Instruction:\n{query}\n\n### Response:\n",
|
||||
sep="\n\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "vicuna":
|
||||
r"""
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
||||
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
prompt="USER: {query} ASSISTANT: ",
|
||||
sep="</s>",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "belle":
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="Human: {query}\n\nBelle: ",
|
||||
sep="\n\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "linly":
|
||||
r"""
|
||||
Supports: https://github.com/CVI-SZU/Linly
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="User: {query}\nBot: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "billa":
|
||||
r"""
|
||||
Supports: https://github.com/Neutralzz/BiLLa
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="Human: {query}\nAssistant: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "ziya":
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="<human>:{query}\n<bot>:",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "aquila":
|
||||
r"""
|
||||
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
prompt="Human: {query}###Assistant: ",
|
||||
sep="###",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "intern":
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
|
||||
sep="<eoa>\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
elif self.name == "baichuan":
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
"""
|
||||
self._register_template(
|
||||
prefix="",
|
||||
prompt="<reserved_102>{query}<reserved_103>",
|
||||
sep="",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Template {} does not exist.".format(self.name))
|
||||
|
||||
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
|
||||
r"""
|
||||
Returns a string containing prompt without response.
|
||||
"""
|
||||
return "".join(self._format_example(query, history, prefix))
|
||||
|
||||
def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
|
||||
r"""
|
||||
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
|
||||
"""
|
||||
return self._format_example(query, history, prefix) + [resp]
|
||||
|
||||
def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
|
||||
self.prefix = prefix
|
||||
self.prompt = prompt
|
||||
self.sep = sep
|
||||
self.use_history = use_history
|
||||
|
||||
def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
|
||||
prefix = prefix if prefix else self.prefix # use prefix if provided
|
||||
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, "<dummy>")]
|
||||
convs = []
|
||||
for turn_idx, (user_query, bot_resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
convs.append(prefix + self.prompt.format(query=user_query))
|
||||
convs.append(bot_resp)
|
||||
else:
|
||||
convs.append(self.sep + self.prompt.format(query=user_query))
|
||||
convs.append(bot_resp)
|
||||
return convs[:-1] # drop last
|
||||
Reference in New Issue
Block a user