fix shift short attention

Former-commit-id: 9a49cce8e6f6b222f74a07bdab40efee6a77b0f1
hiyouga
2023-10-09 17:07:46 +08:00
parent 5c4248a29c
commit e387a50475
6 changed files with 46 additions and 52 deletions

View File

@@ -103,7 +103,6 @@ def load_model_and_tokenizer(
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
@@ -128,7 +127,7 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
-    # Set FlashAttention-2 and S^2-Attn
+    # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
@@ -136,12 +135,22 @@ def load_model_and_tokenizer(
                 LlamaPatches._prepare_decoder_attention_mask
             )
             logger.info("Using FlashAttention-2 for faster training and inference.")
         elif getattr(config, "model_type", None) == "qwen":
             logger.info("Qwen models automatically enable FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
+    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
+        logger.warning("Using `--flash_attn` for faster training in large context length.")
+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
     if model_args.quantization_bit is not None:
@@ -176,14 +185,6 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(model, "shift_ratio", 0.25)
-            logger.info("Using shift short attention proposed by LongLoRA.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
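
Note: the `group_size_ratio` attribute written onto the config above is what the patched LLaMA attention is expected to read at runtime. The sketch below is illustrative only (it is not the repository's `LlamaShiftShortAttention`, and `shift_and_group` is a hypothetical name); it assumes the LongLoRA-style S^2-Attn scheme, where the sequence is split into groups of `seq_len * group_size_ratio` tokens and half of the attention heads are shifted by half a group so information can flow across group boundaries.

# Illustrative sketch only -- assumes the LongLoRA S^2-Attn scheme; the helper
# name below is hypothetical and not part of LLaMA-Factory.
import torch

def shift_and_group(states: torch.Tensor, group_size_ratio: float = 0.25) -> torch.Tensor:
    # states: (batch, num_heads, seq_len, head_dim), e.g. a query/key/value tensor
    bsz, num_heads, seq_len, head_dim = states.shape
    group_size = int(seq_len * group_size_ratio)
    assert seq_len % group_size == 0, "sequence length must be divisible by the group size"

    # Shift the second half of the heads by half a group along the sequence axis,
    # so neighbouring groups overlap and tokens near a boundary can attend across it.
    states = states.clone()
    states[:, num_heads // 2:] = torch.roll(states[:, num_heads // 2:], shifts=-group_size // 2, dims=2)

    # Fold each group into the batch dimension; attention is then computed
    # independently within every group of `group_size` tokens.
    num_groups = seq_len // group_size
    states = states.reshape(bsz, num_heads, num_groups, group_size, head_dim)
    return states.transpose(1, 2).reshape(bsz * num_groups, num_heads, group_size, head_dim)

# With group_size_ratio=0.25, a 2048-token sequence becomes 4 independent
# attention problems of 512 tokens each:
q = torch.randn(1, 32, 2048, 128)
print(shift_and_group(q).shape)  # torch.Size([4, 32, 512, 128])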

View File

@@ -149,6 +149,9 @@ def get_train_args(
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not support PPO training currently.")
+    if general_args.stage == "ppo" and model_args.shift_attn:
+        raise ValueError("PPO training is incompatible with S^2-Attn.")
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")

View File

@@ -29,6 +29,7 @@ def run_dpo(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )

View File

@@ -27,7 +27,7 @@ def run_rm(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset

View File

@@ -33,6 +33,7 @@ def run_sft(
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4, # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
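
Note: the `pad_to_multiple_of=4` added to the collators above pairs with `group_size_ratio=0.25`: with four groups per sequence, the padded length must divide evenly by 4. A minimal sketch of that arithmetic (the helper below is hypothetical, not part of the repository):

# Hypothetical helper, for illustration only.
def pad_to_multiple(seq_len: int, multiple: int = 4) -> int:
    return ((seq_len + multiple - 1) // multiple) * multiple

for raw_len in (510, 511, 512):
    padded = pad_to_multiple(raw_len)
    group_size = padded // 4  # group_size_ratio = 0.25
    assert padded % 4 == 0
    print(f"{raw_len} tokens -> padded to {padded} = 4 groups of {group_size}")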