run style check

Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
This commit is contained in:
Eric Tang
2025-01-06 23:55:56 +00:00
committed by hiyouga
parent 8683582300
commit 4f31ad997c
7 changed files with 54 additions and 35 deletions

View File

@@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = logging.get_logger(__name__)
def training_function(config: Dict[str, Any]) -> None:
args = config.get("args", None)
callbacks = config.get("callbacks", [])
callbacks.append(LogCallback())
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
@@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None:
else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
args_dict = _read_args(args)
ray_args = _parse_ray_args(args_dict)
if ray_args.use_ray:
if ray_args.use_ray:
# Import lazily to avoid ray not installed error
from ..integrations.ray.ray_train import get_ray_trainer
# Initialize ray trainer
trainer = get_ray_trainer(
training_function=training_function,