refactor ray integration, support save ckpt
Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
@@ -18,7 +18,8 @@
 # limitations under the License.
 
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
@@ -31,7 +32,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 
@@ -40,11 +41,16 @@ if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
 
+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer
+
+
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
+    from transformers import PreTrainedModel, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
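Note: the Ray imports above are guarded by is_ray_available(), mirroring the existing is_galore_available() guard, so the module still imports cleanly when Ray is not installed. A minimal sketch of how such an availability check is commonly implemented; the real helper lives in ..extras.packages and its body here is an assumption, not the repository's code:

# Hypothetical sketch of an availability check; the actual helper in
# ..extras.packages may differ in detail.
import importlib.util


def is_ray_available() -> bool:
    # True only if the "ray" package can be found on the import path.
    return importlib.util.find_spec("ray") is not None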
@@ -75,7 +81,7 @@ def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> None:
     kwargs = {
@@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -272,7 +278,7 @@ def _create_galore_optimizer(
 
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     default_lr = training_args.learning_rate
@@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
 
 def _create_badam_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     decay_params, nodecay_params = [], []
@@ -373,7 +379,7 @@ def _create_badam_optimizer(
 
 def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
     from adam_mini import Adam_mini  # type: ignore
 
@@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
 
 def create_custom_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
@@ -415,7 +421,7 @@ def create_custom_optimizer(
 
 
 def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
@@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
         config={"Framework": "🦙LlamaFactory"},
     )
     return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
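A hedged usage sketch of the new get_ray_trainer helper. Only the field names (use_ray, ray_num_workers, resources_per_worker, placement_strategy, ray_run_name) come from the hunk above; the import path, the SimpleNamespace standing in for a parsed RayArguments object, and all concrete values are assumptions for illustration:

# Illustrative only: assumes Ray is installed and that the helper is importable
# from the trainer-utils module patched above (module path is an assumption).
from types import SimpleNamespace

from llamafactory.train.trainer_utils import get_ray_trainer  # assumed path


def demo_training_function(config: dict) -> None:
    # Placeholder for the real training_function(config) defined in the tuner module.
    print("worker received args:", config["args"])


ray_args = SimpleNamespace(      # stands in for a parsed RayArguments object
    use_ray=True,
    ray_num_workers=2,           # one Ray worker per GPU
    resources_per_worker={"GPU": 1},
    placement_strategy="PACK",
    ray_run_name="llamafactory_sft_demo",
)

trainer = get_ray_trainer(
    training_function=demo_training_function,
    train_loop_config={"args": {"stage": "sft"}, "callbacks": []},
    ray_args=ray_args,
)
result = trainer.fit()  # blocks until all workers finish
# In a real run (with RayTrainReportCallback attached), result.checkpoint points
# at the latest checkpoint persisted under the RunConfig storage_path ("./saves").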
@@ -22,8 +22,8 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..hparams import get_infer_args, get_train_args
-from ..hparams.parser import _parse_ray_args, _read_args
+from ..extras.packages import is_ray_available
+from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -32,7 +32,11 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_swanlab_callback
+from .trainer_utils import get_ray_trainer, get_swanlab_callback
+
+
+if is_ray_available():
+    from ray.train.huggingface.transformers import RayTrainReportCallback
 
 
 if TYPE_CHECKING:
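RayTrainReportCallback, imported above when Ray is available, is the piece that reports metrics and checkpoints from each worker back to Ray Train, which persists them under the storage_path configured in get_ray_trainer; this is what the commit title's "support save ckpt" refers to. A minimal hedged sketch of how such a callback attaches to a transformers Trainer; the wiring below is illustrative, since in this commit the callback is simply appended to the callbacks list that run_exp forwards to the stage runners:

# Illustrative only: placeholder model/dataset arguments; the real wiring happens
# inside the run_* stage functions that receive the callbacks list from run_exp.
from ray.train.huggingface.transformers import RayTrainReportCallback
from transformers import Trainer, TrainingArguments


def build_hf_trainer(model, train_dataset) -> Trainer:
    # Hypothetical helper: each save/evaluate step of this Trainer is reported
    # to Ray Train through the callback, so checkpoints are persisted by Ray Train.
    return Trainer(
        model=model,
        args=TrainingArguments(output_dir="saves/demo", save_steps=100),
        train_dataset=train_dataset,
        callbacks=[RayTrainReportCallback()],
    )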
@@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
 
 
 def training_function(config: Dict[str, Any]) -> None:
-    args = config.get("args", None)
-    callbacks = config.get("callbacks", [])
-
-    callbacks.append(LogCallback())
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 
     if finetuning_args.pissa_convert:
@@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 
 
-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    args_dict = _read_args(args)
-    ray_args = _parse_ray_args(args_dict)
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    callbacks = callbacks or []
+    callbacks.append(LogCallback())
+
+    args = read_args(args)
+    ray_args = get_ray_args(args)
     if ray_args.use_ray:
-        # Import lazily to avoid ray not installed error
-        from ..integrations.ray.ray_train import get_ray_trainer
-
-        # Initialize ray trainer
         callbacks.append(RayTrainReportCallback())
         trainer = get_ray_trainer(
             training_function=training_function,
-            train_loop_config={
-                "args": args_dict,
-                "callbacks": callbacks,
-            },
+            train_loop_config={"args": args, "callbacks": callbacks},
             ray_args=ray_args,
         )
         trainer.fit()
     else:
-        training_function(
-            config={
-                "args": args_dict,
-                "callbacks": callbacks,
-            }
-        )
+        training_function(config={"args": args, "callbacks": callbacks})
 
 
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
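For completeness, a hedged sketch of how the refactored run_exp might be driven with Ray enabled. The USE_RAY flag name comes from the error message in get_ray_trainer; the module path, the timing of setting the flag, and the argument keys are assumptions for illustration, not values taken from this commit:

# Illustrative driver; module path and argument keys are assumptions.
import os

from llamafactory.train.tuner import run_exp  # assumed module path

# Flag name taken from the get_ray_trainer error message; assumed to be read
# when run_exp parses the arguments.
os.environ["USE_RAY"] = "1"

run_exp(
    args={
        "stage": "sft",                                       # dispatched in training_function
        "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",   # placeholder model
        "ray_num_workers": 2,                                 # feeds ScalingConfig via RayArguments
        "ray_run_name": "sft_ray_demo",                       # becomes the Ray RunConfig name
    }
)
# With use_ray parsed as True, run_exp appends RayTrainReportCallback, builds the
# TorchTrainer via get_ray_trainer, and calls trainer.fit(); otherwise it runs
# training_function(config={"args": ..., "callbacks": ...}) in-process.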