refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -39,9 +40,6 @@ def _encode_pairwise_example(
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||
@@ -51,10 +49,7 @@ def _encode_pairwise_example(
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
||||
# consider the response is more important
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
@@ -77,27 +72,15 @@ def preprocess_pairwise_dataset(
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {
|
||||
"chosen_input_ids": [],
|
||||
"chosen_attention_mask": [],
|
||||
"chosen_labels": [],
|
||||
"rejected_input_ids": [],
|
||||
"rejected_attention_mask": [],
|
||||
"rejected_labels": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"] = []
|
||||
model_inputs["rejected_token_type_ids"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
@@ -112,15 +95,15 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
||||
)
|
||||
model_inputs["rejected_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
||||
)
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={
|
||||
"chosen_token_type_ids": len(chosen_input_ids),
|
||||
"rejected_token_type_ids": len(rejected_input_ids),
|
||||
},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user