drafting ray integration
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
This commit is contained in:
committed by
hiyouga
parent
5ccc607222
commit
8683582300
@@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..hparams.parser import _parse_ray_args, _read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .dpo import run_dpo
|
||||
@@ -36,12 +37,14 @@ from .trainer_utils import get_swanlab_callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||
def training_function(config: Dict[str, Any]) -> None:
|
||||
args = config.get("args", None)
|
||||
callbacks = config.get("callbacks", [])
|
||||
|
||||
callbacks.append(LogCallback())
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
@@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
||||
else:
|
||||
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||
|
||||
args_dict = _read_args(args)
|
||||
ray_args = _parse_ray_args(args_dict)
|
||||
|
||||
if ray_args.use_ray:
|
||||
# Import lazily to avoid ray not installed error
|
||||
from ..integrations.ray.ray_train import get_ray_trainer
|
||||
|
||||
# Initialize ray trainer
|
||||
trainer = get_ray_trainer(
|
||||
training_function=training_function,
|
||||
train_loop_config={
|
||||
"args": args_dict,
|
||||
"callbacks": callbacks,
|
||||
},
|
||||
ray_args=ray_args,
|
||||
)
|
||||
trainer.fit()
|
||||
else:
|
||||
training_function(
|
||||
config={
|
||||
"args": args_dict,
|
||||
"callbacks": callbacks,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
Reference in New Issue
Block a user