Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
This commit is contained in:
@@ -1,13 +1,26 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
@@ -42,12 +55,8 @@ def _encode_feedback_example(
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
_, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
@@ -57,6 +66,12 @@ def _encode_feedback_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
# do not consider the kl_response
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
response_ids = response_ids[:target_len]
|
||||
kl_response_ids = kl_response_ids[:target_len]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
@@ -31,12 +44,8 @@ def _encode_pairwise_example(
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
@@ -46,6 +55,13 @@ def _encode_pairwise_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(
|
||||
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
|
||||
) # consider the response is more important
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
chosen_ids = chosen_ids[:target_len]
|
||||
rejected_ids = rejected_ids[:target_len]
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
|
||||
@@ -1,9 +1,26 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
|
||||
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
|
||||
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
|
||||
@@ -1,5 +1,19 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
from typing import TYPE_CHECKING, List, Sequence, Tuple
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
@@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
||||
|
||||
|
||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
||||
if target_len * 2 < cutoff_len: # truncate source
|
||||
max_target_len = cutoff_len
|
||||
elif source_len * 2 < cutoff_len: # truncate target
|
||||
max_target_len = cutoff_len - source_len
|
||||
else: # truncate both
|
||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
||||
|
||||
new_target_len = min(max_target_len, target_len)
|
||||
new_source_len = max(cutoff_len - new_target_len, 0)
|
||||
return new_source_len, new_target_len
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
@@ -38,10 +51,17 @@ def _encode_supervised_example(
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = 1 if template.efficient_eos else 0
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if total_length >= data_args.cutoff_len:
|
||||
break
|
||||
|
||||
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
|
||||
source_ids = source_ids[:source_len]
|
||||
target_ids = target_ids[:target_len]
|
||||
total_length += source_len + target_len
|
||||
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
@@ -34,9 +47,7 @@ def _encode_unsupervised_example(
|
||||
else:
|
||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
@@ -44,6 +55,9 @@ def _encode_unsupervised_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
|
||||
input_ids = input_ids[:source_len]
|
||||
labels = labels[:target_len]
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user