Refactor Ray integration and support saving checkpoints

Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
This commit is contained in:
hiyouga
2025-01-07 08:54:41 +00:00
parent 4f31ad997c
commit 944a2aec4d
18 changed files with 215 additions and 161 deletions

View File

@@ -22,8 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..hparams.parser import _parse_ray_args, _read_args
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -32,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
def training_function(config: Dict[str, Any]) -> None:
args = config.get("args", None)
callbacks = config.get("callbacks", [])
callbacks.append(LogCallback())
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
if finetuning_args.pissa_convert:
@@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
args_dict = _read_args(args)
ray_args = _parse_ray_args(args_dict)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
callbacks = callbacks or []
callbacks.append(LogCallback())
args = read_args(args)
ray_args = get_ray_args(args)
if ray_args.use_ray:
# Import lazily to avoid ray not installed error
from ..integrations.ray.ray_train import get_ray_trainer
# Initialize ray trainer
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=training_function,
train_loop_config={
"args": args_dict,
"callbacks": callbacks,
},
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
training_function(
config={
"args": args_dict,
"callbacks": callbacks,
}
)
training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None: