Former-commit-id: 9a30ee5009040afbc524dbac0dad99904b2adf5f
@@ -5,9 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 
 from transformers import TrainerCallback
-from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
-from transformers.training_args import TrainingArguments
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
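
Note on the two deletions above: TrainerControl, TrainerState, and TrainingArguments are only referenced as string annotations in the callback signatures further down, and the hunk header shows "from typing import TYPE_CHECKING" already in scope, so the imports can move behind a TYPE_CHECKING guard. Where this commit re-adds the guarded imports is outside the visible hunks, so the sketch below (with a hypothetical callback_event function) illustrates the idiom rather than quoting the new code:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # seen by type checkers only, skipped at runtime
        from transformers import TrainerControl, TrainerState, TrainingArguments

    def callback_event(args: "TrainingArguments", state: "TrainerState", control: "TrainerControl") -> None:
        ...  # string annotations resolve without a runtime import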
@@ -27,14 +25,18 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+            model = kwargs.pop("model")
+            if getattr(model, "is_peft_model", False):
+                getattr(model, "pretrained_model").save_pretrained(output_dir)
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+            model = kwargs.pop("model")
+            if getattr(model, "is_peft_model", False):
+                getattr(model, "pretrained_model").save_pretrained(args.output_dir)
 
 
 class LogCallback(TrainerCallback):
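
The change in both methods is the same: instead of unconditionally dereferencing pretrained_model on whatever object the trainer passes in, the callback now pops the model out of kwargs and saves only when it advertises is_peft_model, so a model without a pretrained_model attribute no longer raises AttributeError here. A runnable toy illustration of what the new guard changes, with hypothetical stand-in classes (Saver, PeftWrapper, PlainModel are not from the repo):

    class Saver:
        def save_pretrained(self, path):
            print("saved adapter to", path)

    class PeftWrapper:            # e.g. a value-head wrapper around a PEFT model
        is_peft_model = True
        pretrained_model = Saver()

    class PlainModel:             # no pretrained_model attribute at all
        pass

    for m in (PeftWrapper(), PlainModel()):
        if getattr(m, "is_peft_model", False):  # new guard: save PEFT models only
            getattr(m, "pretrained_model").save_pretrained("/tmp/ckpt")
        # the old unconditional getattr(m, "pretrained_model") would raise
        # AttributeError for PlainModel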
@@ -230,7 +230,8 @@ class LlamaAttention(torch.nn.Module):
             new_len = past_len+q.size(1)
             if new_len > past_kv.size(1):
                 past_kv = torch.cat(
-                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
+                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
+                    dim=1
                 )
             past_kv[:, past_len:new_len] = kv
             kv = past_kv[:, :new_len]
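
This hunk (from the LlamaAttention flash-attention patch, per the hunk header) only makes the positional 1 an explicit dim=1; behavior is identical. The surrounding logic is the interesting part: the packed KV cache is over-allocated in chunks of 256 positions along the sequence dimension, so the torch.cat reallocation runs once per 256 decoded tokens instead of on every step. A standalone sketch of that pattern with simplified toy shapes (CHUNK and the sizes are illustrative):

    import torch

    CHUNK = 256
    cache = torch.empty(1, 0, 2, 4, 8)   # (bsz, seq, k/v, n_heads, head_dim)
    past_len = 0
    for _ in range(600):                  # simulate 600 decode steps
        kv = torch.randn(1, 1, 2, 4, 8)   # one new position per step
        new_len = past_len + kv.size(1)
        if new_len > cache.size(1):       # out of room: grow by one chunk
            cache = torch.cat([cache, torch.empty(1, CHUNK, 2, 4, 8)], dim=1)
        cache[:, past_len:new_len] = kv   # write the new position in place
        past_len = new_len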
@@ -248,20 +249,18 @@ class LlamaAttention(torch.nn.Module):
             attn_outputs = flash_attn_varlen_kvpacked_func(
                 unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_q, max_seqlen_k,
-                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
+                dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
                 causal=(not has_layer_past), return_attn_probs=output_attentions
             )
 
             attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = pad_input(
-                attn_output, indices_q, bsz, q_len
-            ).reshape(bsz, q_len, h_size)
+            attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
             attn_weights = attn_outputs[2] if output_attentions else None
 
         else:
             # no padding tokens, more efficient
             attn_outputs = flash_attn_kvpacked_func(
-                q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
+                q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
                 causal=(not has_layer_past), return_attn_probs=output_attentions
             )
             attn_output = attn_outputs[0] if output_attentions else attn_outputs
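
Both edits in this final hunk are formatting only: spaces around / in the softmax_scale expressions, and the pad_input(...) call collapsed onto one line. For context, the padded branch gathers the real tokens out of the padded batch, runs FlashAttention's varlen kernel on the packed sequence, and scatters the output back; flash_attn's unpad_input/pad_input do this plus compute the cu_seqlens offsets the kernel needs. A pure-PyTorch sketch of just the gather/scatter round trip (toy shapes, no flash_attn dependency):

    import torch

    bsz, seqlen, hidden = 2, 4, 8
    x = torch.randn(bsz, seqlen, hidden)
    mask = torch.tensor([[1, 1, 0, 0],
                         [1, 1, 1, 0]], dtype=torch.bool)  # True = real token

    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    unpadded = x.reshape(bsz * seqlen, hidden)[indices]    # (total_tokens, hidden)

    repadded = torch.zeros(bsz * seqlen, hidden)
    repadded[indices] = unpadded                           # inverse scatter
    repadded = repadded.reshape(bsz, seqlen, hidden)
    assert torch.equal(repadded[mask], x[mask])            # real tokens survive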