Merge branch 'main' into main

Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
hoshi-hiyouga (committed by GitHub), 2024-07-01 21:01:09 +08:00
176 changed files with 4760 additions and 1322 deletions


@@ -1,217 +0,0 @@
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, output_dir: str) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
self.aborted = False
self.do_train = False
""" Web UI """
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
def _create_thread_pool(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _close_thread_pool(self) -> None:
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
)
)
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if has_length(eval_dataloader):
if self.max_steps == 0:
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)
self._timing(cur_steps=self.cur_steps + 1)
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
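As a rough sketch of how the callbacks above are typically wired up (illustrative only, not part of this diff; `model` and `train_dataset` are assumed to exist, and FixValueHeadModelCallback is only meaningful when the model is an AutoModelForCausalLMWithValueHead):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="outputs/run1", save_safetensors=True)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.add_callback(LogCallback(output_dir=training_args.output_dir))
trainer.add_callback(FixValueHeadModelCallback())
trainer.train()  # progress is appended as JSON lines to TRAINER_LOG under output_dir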


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
@@ -404,6 +418,18 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
"DeepSeek-MoE-Coder-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
},
"DeepSeek-MoE-Coder-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
},
"DeepSeek-MoE-Coder-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
},
"DeepSeek-MoE-Coder-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
},
},
template="deepseek",
)
@@ -496,6 +522,18 @@ register_model_group(
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
"Gemma-2-9B": {
DownloadSource.DEFAULT: "google/gemma-2-9b",
},
"Gemma-2-27B": {
DownloadSource.DEFAULT: "google/gemma-2-27b",
},
"Gemma-2-9B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
},
"Gemma-2-27B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
},
},
template="gemma",
)
@@ -568,7 +606,7 @@ register_model_group(
register_model_group(
models={
"Jambda-v0.1": {
"Jamba-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
@@ -683,6 +721,21 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
},
"MiniCPM-2B-DPO-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
},
},
template="cpm",
)
register_model_group(
models={
"Mistral-7B-v0.1": {


@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import accelerate
@@ -9,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.8.1.dev0"
VERSION = "0.8.3.dev0"
def print_env() -> None:


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys


@@ -1,13 +1,29 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Tuple
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
@@ -16,7 +32,6 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
@@ -28,8 +43,6 @@ except Exception:
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
@@ -58,6 +71,9 @@ class AverageMeter:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
@@ -68,7 +84,7 @@ def check_dependencies() -> None:
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
@@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
@@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
There are three cases:
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
We assume `stage3_gather_16bit_weights_on_model_save=true`.
"""
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
return
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
decoder_state_dict = {}
v_head_state_dict = {}
for name, param in state_dict.items():
if name.startswith("v_head."):
v_head_state_dict[name] = param
else:
decoder_state_dict[name.replace("pretrained_model.", "")] = param
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir))
def get_current_device() -> torch.device:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
@@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
@@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
def has_tokenized_data(path: os.PathLike) -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def skip_check_imports() -> None:
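    r"""
    Avoids checking third-party imports of remote code: replaces transformers'
    `check_imports` with `get_relative_imports` (unless FORCE_CHECK_IMPORTS is set),
    so modules loaded via `trust_remote_code` do not fail on optional dependencies.
    """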
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:


@@ -1,5 +1,23 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
@@ -24,10 +42,6 @@ def is_fastapi_available():
return _is_package_available("fastapi")
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available():
return _is_package_available("galore_torch")
@@ -36,18 +50,10 @@ def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
def is_matplotlib_available():
return _is_package_available("matplotlib")
def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available():
return _is_package_available("PIL")
@@ -60,10 +66,6 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available():
return _is_package_available("sse_starlette")
@@ -74,3 +76,8 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
@lru_cache
def is_vllm_version_greater_than_0_5():
return _get_package_version("vllm") >= version.parse("0.5.0")
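A hedged usage sketch (the call site below is an assumption, not part of this diff) of how the cached helper is typically consumed to branch on the installed vLLM version:

if is_vllm_available() and is_vllm_version_greater_than_0_5():
    ...  # take the code path written against the vLLM >= 0.5.0 API
else:
    ...  # fall back to the pre-0.5.0 behavior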


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os