refactor pissa, improve llamaboard
Former-commit-id: 619556e46c19718f702c97df5d570a2a4c5fb13a
This commit is contained in:
@@ -1,231 +0,0 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
|
||||
from .constants import TRAINER_LOG
|
||||
from .logging import LoggerHandler, get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||
safe_serialization=args.save_safetensors,
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = LoggerHandler(output_dir)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def _set_abort(self, signum, frame) -> None:
|
||||
self.aborted = True
|
||||
|
||||
def _reset(self, max_steps: int = 0) -> None:
|
||||
self.start_time = time.time()
|
||||
self.cur_steps = 0
|
||||
self.max_steps = max_steps
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
|
||||
def _timing(self, cur_steps: int) -> None:
|
||||
cur_time = time.time()
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
|
||||
self.cur_steps = cur_steps
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def _create_thread_pool(self, output_dir: str) -> None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _close_thread_pool(self) -> None:
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
"""
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
self._timing(cur_steps=state.global_step)
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
loss=state.log_history[-1].get("loss", None),
|
||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||
reward=state.log_history[-1].get("reward", None),
|
||||
accuracy=state.log_history[-1].get("rewards/accuracies", None),
|
||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||
epoch=state.log_history[-1].get("epoch", None),
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
|
||||
total_tokens=state.num_input_tokens_seen,
|
||||
)
|
||||
logs = {k: v for k, v in logs.items() if v is not None}
|
||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||
logger.info(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
|
||||
)
|
||||
)
|
||||
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
if self.do_train:
|
||||
return
|
||||
|
||||
if self.aborted:
|
||||
sys.exit(0)
|
||||
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if has_length(eval_dataloader):
|
||||
if self.max_steps == 0:
|
||||
self._reset(max_steps=len(eval_dataloader))
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
@@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's PEFT library.
|
||||
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,15 +17,11 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_safetensors_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
@@ -31,15 +30,9 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
try:
|
||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||
@@ -48,8 +41,6 @@ except Exception:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
|
||||
@@ -99,7 +90,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||
num_bytes = param.quant_storage.itemsize
|
||||
@@ -117,51 +108,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
|
||||
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
|
||||
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
|
||||
|
||||
We assume `stage3_gather_16bit_weights_on_model_save=true`.
|
||||
"""
|
||||
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
|
||||
return
|
||||
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
decoder_state_dict = {}
|
||||
v_head_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "")] = param
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
if safe_serialization:
|
||||
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
r"""
|
||||
Gets the current available device.
|
||||
@@ -201,7 +147,7 @@ def get_logits_processor() -> "LogitsProcessorList":
|
||||
return logits_processor
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
@@ -220,7 +166,7 @@ def is_gpu_or_npu_available() -> bool:
|
||||
return is_torch_npu_available() or is_torch_cuda_available()
|
||||
|
||||
|
||||
def has_tokenized_data(path: os.PathLike) -> bool:
|
||||
def has_tokenized_data(path: "os.PathLike") -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user