refactor ray integration, support save ckpt

Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
This commit is contained in:
hiyouga
2025-01-07 08:54:41 +00:00
parent 4f31ad997c
commit 944a2aec4d
18 changed files with 215 additions and 161 deletions

View File

@@ -18,7 +18,8 @@
# limitations under the License.
from collections.abc import Mapping
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -31,7 +32,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.packages import is_galore_available
from ..extras.packages import is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,11 +41,16 @@ if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if is_ray_available():
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
from transformers import PreTrainedModel, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -75,7 +81,7 @@ def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
@@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -272,7 +278,7 @@ def _create_galore_optimizer(
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
@@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
@@ -373,7 +379,7 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini # type: ignore
@@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
@@ -415,7 +421,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
@@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
config={"Framework": "🦙LlamaFactory"},
)
return swanlab_callback
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_path=Path("./saves").absolute().as_posix(),
),
)
return trainer

View File

@@ -22,8 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..hparams.parser import _parse_ray_args, _read_args
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -32,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
def training_function(config: Dict[str, Any]) -> None:
args = config.get("args", None)
callbacks = config.get("callbacks", [])
callbacks.append(LogCallback())
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
if finetuning_args.pissa_convert:
@@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
args_dict = _read_args(args)
ray_args = _parse_ray_args(args_dict)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
callbacks = callbacks or []
callbacks.append(LogCallback())
args = read_args(args)
ray_args = get_ray_args(args)
if ray_args.use_ray:
# Import lazily to avoid ray not installed error
from ..integrations.ray.ray_train import get_ray_trainer
# Initialize ray trainer
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=training_function,
train_loop_config={
"args": args_dict,
"callbacks": callbacks,
},
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
training_function(
config={
"args": args_dict,
"callbacks": callbacks,
}
)
training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None: