fix flashattn + packing
Former-commit-id: 4adc6ce4abc718c25f39b316bfc3352d0d01ed1e
This commit is contained in:
@@ -37,7 +37,7 @@ def _encode_pairwise_example(
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
@@ -55,9 +55,8 @@ def _encode_pairwise_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(
|
||||
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
|
||||
) # consider the response is more important
|
||||
# consider the response is more important
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
chosen_ids = chosen_ids[:target_len]
|
||||
rejected_ids = rejected_ids[:target_len]
|
||||
@@ -105,7 +104,7 @@ def preprocess_pairwise_dataset(
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
cutoff_len=data_args.cutoff_len,
|
||||
)
|
||||
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
||||
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
||||
|
||||
Reference in New Issue
Block a user