support mllm hf inference
Former-commit-id: 2c7c01282acd7ddabbb17ce3246b8dae4bc4b8cf
This commit is contained in:
@@ -24,8 +24,9 @@ def run_dpo(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
|
||||
@@ -24,8 +24,9 @@ def run_orpo(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
|
||||
@@ -27,8 +27,9 @@ def run_ppo(
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
|
||||
@@ -25,8 +25,9 @@ def run_pt(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
|
||||
@@ -25,8 +25,9 @@ def run_rm(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
|
||||
@@ -29,9 +29,9 @@ def run_sft(
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
||||
@@ -52,7 +52,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
|
||||
raise ValueError("Please merge adapters before quantizing the model.")
|
||||
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
tokenizer = load_tokenizer(model_args)["tokenizer"]
|
||||
get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ def create_ref_model(
|
||||
)
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
tokenizer = load_tokenizer(ref_model_args)
|
||||
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
|
||||
ref_model = load_model(
|
||||
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
@@ -100,7 +100,7 @@ def create_ref_model(
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
tokenizer = load_tokenizer(model_args)
|
||||
tokenizer = load_tokenizer(model_args)["tokenizer"]
|
||||
ref_model = load_model(
|
||||
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
@@ -147,7 +147,7 @@ def create_reward_model(
|
||||
)
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
tokenizer = load_tokenizer(reward_model_args)
|
||||
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
|
||||
reward_model = load_model(
|
||||
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user