update web UI, support rm predict #210
Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
This commit is contained in:
@@ -54,7 +54,7 @@ def get_train_args(
|
||||
assert not (training_args.do_train and training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True while training."
|
||||
|
||||
assert (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
"Please enable `predict_with_generate` to save model predictions."
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
|
||||
@@ -4,7 +4,8 @@ from typing import Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
@@ -49,9 +50,9 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
else:
|
||||
backbone_model = model
|
||||
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
if isinstance(backbone_model, PeftModel): # LoRA tuning
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
|
||||
else: # freeze/full tuning
|
||||
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
|
||||
backbone_model.config.use_cache = True
|
||||
backbone_model.save_pretrained(
|
||||
output_dir,
|
||||
@@ -61,6 +62,8 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
backbone_model.config.use_cache = False
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
else:
|
||||
logger.warning("No model to save.")
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
model = unwrap_model(self.model)
|
||||
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
|
||||
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
|
||||
if isinstance(backbone_model, PeftModel):
|
||||
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
|
||||
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
|
||||
Reference in New Issue
Block a user