fix llamaboard with ray
Former-commit-id: bd8a432d6a980b1b24a551626304fe3d394b1baf
This commit is contained in:
@@ -46,11 +46,12 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def training_function(config: Dict[str, Any]) -> None:
|
||||
def _training_function(config: Dict[str, Any]) -> None:
|
||||
args = config.get("args")
|
||||
callbacks: List[Any] = config.get("callbacks")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
callbacks.append(LogCallback())
|
||||
if finetuning_args.pissa_convert:
|
||||
callbacks.append(PissaConvertCallback())
|
||||
|
||||
@@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None:
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||
callbacks = callbacks or []
|
||||
callbacks.append(LogCallback())
|
||||
|
||||
args = read_args(args)
|
||||
ray_args = get_ray_args(args)
|
||||
callbacks = callbacks or []
|
||||
if ray_args.use_ray:
|
||||
callbacks.append(RayTrainReportCallback())
|
||||
trainer = get_ray_trainer(
|
||||
training_function=training_function,
|
||||
training_function=_training_function,
|
||||
train_loop_config={"args": args, "callbacks": callbacks},
|
||||
ray_args=ray_args,
|
||||
)
|
||||
trainer.fit()
|
||||
else:
|
||||
training_function(config={"args": args, "callbacks": callbacks})
|
||||
_training_function(config={"args": args, "callbacks": callbacks})
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
|
||||
Reference in New Issue
Block a user