refactor data preprocessing, fix mllm rlhf

Former-commit-id: 53ff2dd24f9121ea30c95063bb72e49a9b31e980
This commit is contained in:
hiyouga
2024-05-24 04:08:25 +08:00
parent 1078611259
commit bf59383783
15 changed files with 572 additions and 464 deletions

View File

@@ -61,7 +61,7 @@ class HuggingfaceEngine(BaseEngine):
and image is not None
and not hasattr(processor, "image_seq_length")
and IMAGE_TOKEN not in messages[0]["content"]
): # llava case
): # llava-like models
messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
@@ -74,7 +74,7 @@ class HuggingfaceEngine(BaseEngine):
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma case
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids