alter rewards data type
Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )

 check_min_version("4.29.1")
@@ -128,7 +127,7 @@ def init_adapter(

 def load_pretrained(
     model_args: ModelArguments,
-    finetuning_args: Optional[FinetuningArguments] = None,
+    finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@@ -137,16 +136,9 @@ def load_pretrained(

     Support both training and inference.
     """
-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")

     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
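For clarity, below is a minimal, self-contained sketch of the contract the last two hunks establish: finetuning_args is now a required argument, and the evaluation-time fallback fires only when no checkpoint is supplied. The helper name resolve_finetuning_args and the trimmed-down dataclass fields are hypothetical, introduced only to make the sketch runnable; they are not the repository's real API.

import logging
from dataclasses import dataclass
from typing import List, Optional

logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    # Hypothetical stand-in; the real class carries many more fields.
    checkpoint_dir: Optional[List[str]] = None

@dataclass
class FinetuningArguments:
    # Hypothetical stand-in; the real class carries many more fields.
    finetuning_type: str = "lora"

def resolve_finetuning_args(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool = False,
    stage: str = "sft",
) -> FinetuningArguments:
    # Mirrors the patched load_pretrained(): callers must pass finetuning_args
    # explicitly, and at evaluation time without a checkpoint the function
    # falls back to loading the plain pre-trained model.
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")
    # The stage restriction from the patch is unchanged.
    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."
    return finetuning_args

# Evaluation without a checkpoint now degrades to the plain pre-trained model.
args = resolve_finetuning_args(ModelArguments(), FinetuningArguments(), is_trainable=False)
assert args.finetuning_type == "none"

Note that the removed branch was the only visible consumer of FINETUNING_ARGS_NAME here, hence the first hunk drops it from the import list.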