@@ -10,7 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies
+from ..extras.misc import check_dependencies, get_current_device
 from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -235,6 +235,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
 
+    model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
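The added line pins the entire model to the local process's device at training time: the empty-string key in a Hugging Face `device_map` covers the whole module tree. For reference, a minimal sketch of what a `get_current_device` helper along these lines could look like; the actual implementation in `..extras.misc` may differ (e.g. extra NPU/MPS handling), and the `LOCAL_RANK` resolution here is an assumption based on standard `torchrun` behavior:

```python
import os

import torch


def get_current_device() -> torch.device:
    # Sketch only: resolve the device owned by the current distributed process.
    # Assumes the launcher (e.g. torchrun) exports LOCAL_RANK; the real helper
    # may also cover non-CUDA backends.
    if torch.cuda.is_available():
        return torch.device("cuda:" + os.environ.get("LOCAL_RANK", "0"))
    return torch.device("cpu")
```

Mapping every module to one device this way matches the data-parallel setup, where each rank loads a full copy of the model on its own GPU.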
@@ -278,8 +279,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
 
     if model_args.export_dir is not None:
-        model_args.device_map = {"": "cpu"}
-        model_args.compute_dtype = torch.float32
+        model_args.device_map = {"": torch.device(model_args.export_device)}
     else:
         model_args.device_map = "auto"
 
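On the inference side, export previously hardcoded CPU placement and forced `float32`; it now follows the user-supplied `model_args.export_device`, so exports can target an accelerator as well. A hedged usage sketch of how such a map is consumed downstream — the `export_device` value and model id are placeholders, not from this diff:

```python
import torch
from transformers import AutoModelForCausalLM

export_device = "cpu"  # placeholder for model_args.export_device
device_map = {"": torch.device(export_device)}  # "" = place the whole model

# Placing every module on a single device mirrors what the parser sets up
# for export; the checkpoint name below is purely illustrative.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
```

Note that dropping the hardcoded `compute_dtype = torch.float32` means export no longer forces fp32 alongside the device choice.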