Fix packing for eager/SDPA attention implementations
Former-commit-id: 735a033ceb7f2da6da71d138ea091d8a665411a9
This commit is contained in:
@@ -79,9 +79,8 @@ def fix_valuehead_checkpoint(
|
||||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "")] = param
|
||||
decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
)
|
||||
@@ -91,6 +90,7 @@ def fix_valuehead_checkpoint(
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, split_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
@@ -54,10 +52,13 @@ def run_sft(
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
|
||||
Reference in New Issue
Block a user