Merge branch 'main' into minicpmv

Former-commit-id: d8840ae416660e23f1d615ffd404f519360151d9
Zhangchi Feng
2025-01-10 20:12:07 +08:00
committed by GitHub
41 changed files with 647 additions and 357 deletions

View File

@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning_once("There is no current event loop, creating a new one.")
logger.warning_rank0_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
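
For context, the pattern this hunk touches is the standard asyncio fallback; a minimal standalone sketch (not the repository's exact code):

```
import asyncio

# Reuse the thread's current event loop when one exists,
# otherwise create one and register it as current.
def get_or_create_event_loop() -> "asyncio.AbstractEventLoop":
    try:
        return asyncio.get_event_loop()
    except RuntimeError:  # no current event loop in this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
```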

View File

@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count
+from .extras.misc import get_device_count, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -87,7 +87,7 @@ def main():
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-if force_torchrun or get_device_count() > 1:
+if force_torchrun or (get_device_count() > 1 and not use_ray()):
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")

View File

@@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
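
The two mix strategies handled above map directly onto `datasets` primitives; a toy illustration with made-up data:

```
from datasets import Dataset, concatenate_datasets, interleave_datasets

a = Dataset.from_dict({"text": ["a1", "a2"]})
b = Dataset.from_dict({"text": ["b1", "b2"]})

merged = concatenate_datasets([a, b])  # mix_strategy="concat"
mixed = interleave_datasets([a, b], probabilities=[0.5, 0.5], seed=42)  # mix_strategy="interleave_*"
```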

View File

@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
-from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
-from ..extras.misc import has_tokenized_data
+from ..extras.misc import check_version, has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
@@ -103,7 +102,7 @@ def _load_single_dataset(
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore

View File

@@ -75,10 +75,14 @@ class BasePlugin:
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
)
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
raise ValueError(
"This model does not support video input. Please check whether the correct `template` is used."
)
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""

View File

@@ -15,10 +15,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
-from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras import logging
+from ..extras.misc import check_version
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
@@ -44,7 +44,6 @@ class Template:
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
@@ -113,9 +112,6 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
-if i > 0 and i % 2 == 0:
-elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
@@ -180,9 +176,6 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
-if i > 0 and i % 2 == 0:
-elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
@@ -210,7 +203,6 @@ def _register_template(
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
-format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
@@ -224,34 +216,28 @@ def _register_template(
To add the following chat template:
```
-[HUMAN]:
-user prompt here
-[AI]:
-model response here
-[HUMAN]:
-user prompt here
-[AI]:
-model response here
+<s><user>user prompt here
+<model>model response here</s>
+<user>user prompt here
+<model>model response here</s>
```
The corresponding code should be:
```
_register_template(
name="custom",
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
format_prefix=EmptyFormatter("<s>"),
)
```
"""
-template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
+template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots)
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
-default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
@@ -260,7 +246,6 @@ def _register_template(
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
-format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
@@ -344,9 +329,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
-assistant_message = _convert_slots_to_jinja(
-template.format_assistant.apply() + template.format_separator.apply(), tokenizer
-)
+assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
@@ -365,7 +348,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
check_version("transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
@@ -411,7 +394,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
-format_separator=EmptyFormatter(slots=["\n\n"]),
+format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
default_system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
@@ -423,13 +406,13 @@ _register_template(
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
-format_separator=EmptyFormatter(slots=["###"]),
+format_assistant=StringFormatter(slots=["{{content}}###"]),
format_system=StringFormatter(slots=["System: {{content}}###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
efficient_eos=True,
)
@@ -459,7 +442,7 @@ _register_template(
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-format_separator=EmptyFormatter(slots=["\n\n"]),
+format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -481,7 +464,6 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
-format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
)
@@ -506,9 +488,9 @@ _register_template(
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
@@ -519,9 +501,9 @@ _register_template(
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
@@ -574,9 +556,11 @@ _register_template(
)
# copied from chatml template
_register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
@@ -603,9 +587,9 @@ _register_template(
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
@@ -622,7 +606,6 @@ _register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
-replace_eos=True,
)
@@ -644,8 +627,7 @@ _register_template(
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
@@ -659,8 +641,8 @@ _register_template(
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}\n"]),
)
@@ -673,22 +655,22 @@ _register_template(
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
-format_separator=EmptyFormatter(slots=["\n"]),
+format_assistant=StringFormatter(slots=["{{content}}\n"]),
efficient_eos=True,
)
_register_template(
name="fewshot",
-format_separator=EmptyFormatter(slots=["\n\n"]),
+format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True,
)
@@ -696,12 +678,11 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
-format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-efficient_eos=True,
)
@@ -726,8 +707,8 @@ _register_template(
"<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
]
),
+format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
)
@@ -742,22 +723,20 @@ _register_template(
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
-format_separator=EmptyFormatter(slots=["<eoa>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<eoa>"],
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
@@ -888,6 +867,7 @@ _register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
@@ -900,16 +880,15 @@ _register_template(
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
-replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
@@ -918,10 +897,9 @@ _register_template(
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
-replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
@@ -943,6 +921,7 @@ _register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
@@ -955,10 +934,9 @@ _register_template(
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
-replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
@@ -967,16 +945,15 @@ _register_template(
_register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),
stop_words=["<|im_end|>"],
-replace_eos=True,
)
@@ -984,6 +961,7 @@ _register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
@@ -1017,7 +995,6 @@ _register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
-replace_eos=True,
)
@@ -1025,9 +1002,9 @@ _register_template(
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
)
@@ -1044,12 +1021,11 @@ _register_template(
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
-format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
@@ -1057,28 +1033,37 @@ _register_template(
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
+_register_template(
+name="phi4",
+format_user=StringFormatter(
+slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
+),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
+stop_words=["<|im_end|>"],
+)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
@@ -1088,13 +1073,13 @@ _register_template(
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
)
@@ -1104,13 +1089,13 @@ _register_template(
_register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
@@ -1120,8 +1105,8 @@ _register_template(
_register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant named Sailor created by Sea AI Lab. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
@@ -1173,10 +1158,9 @@ _register_template(
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
+format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
-replace_eos=True,
)
@@ -1239,8 +1223,8 @@ _register_template(
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
+format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
-format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
@@ -1260,17 +1244,16 @@ _register_template(
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
-replace_eos=True,
)
_register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
-format_separator=EmptyFormatter(slots=["\n"]),
+format_assistant=StringFormatter(slots=["{{content}}\n"]),
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
@@ -1287,9 +1270,8 @@ _register_template(
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
-format_separator=EmptyFormatter(slots=["\n"]),
+format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
stop_words=["<eod>"],
-replace_eos=True,
)
@@ -1304,5 +1286,5 @@ _register_template(
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
-format_separator=EmptyFormatter(slots=["\n"]),
+format_assistant=StringFormatter(slots=["{{content}}\n"]),
)
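
Taken together, the template hunks above fold the old `format_separator` into an explicit `format_assistant`. Under the new convention, a registration might look like this minimal sketch (hypothetical name, mirroring the chatml pattern shown in this diff):

```
_register_template(
    name="my_chatml",  # hypothetical
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),  # trailing "\n" replaces the old separator
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
)
```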

View File

@@ -1424,6 +1424,14 @@ register_model_group(
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
"Phi-3.5-4B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
},
"Phi-3.5-MoE-42B-A6.6B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
},
},
template="phi",
)
@@ -1444,6 +1452,17 @@ register_model_group(
)
+register_model_group(
+models={
+"Phi-4-14B-Instruct": {
+DownloadSource.DEFAULT: "microsoft/phi-4",
+DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
+},
+},
+template="phi4",
+)
register_model_group(
models={
"Pixtral-12B-Instruct": {

View File

@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
-A logger that supports info_rank0 and warning_once.
+A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
-def warning_once(self, *args, **kwargs) -> None:
+def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
-logging.Logger.warning_once = warning_once
+logging.Logger.warning_rank0_once = warning_rank0_once
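
The renamed helper composes two behaviors; a minimal standalone sketch of the same idea (assuming single-node `LOCAL_RANK` semantics):

```
import logging
import os
from functools import lru_cache

# lru_cache(None) memoizes on (logger, msg), so each distinct warning fires
# at most once, and only on the rank-0 process.
@lru_cache(None)
def warning_rank0_once(logger: logging.Logger, msg: str) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        logger.warning(msg)
```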

View File

@@ -73,19 +73,31 @@ class AverageMeter:
self.avg = self.sum / self.count
+def check_version(requirement: str, mandatory: bool = False) -> None:
+r"""
+Optionally checks the package version.
+"""
+if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+return
+if mandatory:
+hint = f"To fix: run `pip install {requirement}`."
+else:
+hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+require_version(requirement, hint)
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
-if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-return
-require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
-require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+check_version("transformers>=4.41.2,<=4.46.1")
+check_version("datasets>=2.16.0,<=3.1.0")
+check_version("accelerate>=0.34.0,<=1.0.1")
+check_version("peft>=0.11.1,<=0.12.0")
+check_version("trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -229,7 +241,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
@@ -275,8 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def use_ray() -> bool:
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -62,6 +62,10 @@ def is_pillow_available():
return _is_package_available("PIL")
+def is_ray_available():
+return _is_package_available("ray")
def is_requests_available():
return _is_package_available("requests")

View File

@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
-from .parser import get_eval_args, get_infer_args, get_train_args
+from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
+from .training_args import RayArguments, TrainingArguments
__all__ = [
@@ -26,7 +27,11 @@ __all__ = [
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"RayArguments",
"TrainingArguments",
"get_eval_args",
"get_infer_args",
"get_ray_args",
"get_train_args",
"read_args",
]

View File

@@ -15,56 +15,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
import sys
-from typing import Any, Dict, Optional, Tuple
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import transformers
-from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+import yaml
+from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
-from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
+from .training_args import RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
check_dependencies()
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
if args is not None:
-return parser.parse_dict(args)
+return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+return json.loads(Path(sys.argv[1]).absolute().read_text())
+else:
+return sys.argv[1:]
-if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+def _parse_args(
+parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
+args = read_args(args)
+if isinstance(args, dict):
+return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
-if unknown_args:
+(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
+if unknown_args and not allow_extra_keys:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
@@ -110,58 +121,61 @@ def _verify_model_args(
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
training_args: Optional["TrainingArguments"] = None,
) -> None:
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
check_version("unsloth", mandatory=True)
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
check_version("liger-kernel", mandatory=True)
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
check_version("vllm>=0.4.3,<0.6.7")
check_version("vllm", mandatory=True)
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
check_version("galore_torch", mandatory=True)
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
check_version("badam>=1.2.1", mandatory=True)
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
check_version("adam-mini", mandatory=True)
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
check_version("matplotlib", mandatory=True)
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
@@ -371,7 +385,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
return model_args, data_args, training_args, finetuning_args, generating_args
-def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
@@ -404,7 +418,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
return model_args, data_args, finetuning_args, generating_args
-def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()

View File

@@ -0,0 +1,48 @@
+import json
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Union
+from transformers import Seq2SeqTrainingArguments
+from transformers.training_args import _convert_str_dict
+from ..extras.misc import use_ray
+@dataclass
+class RayArguments:
+r"""
+Arguments pertaining to the Ray training.
+"""
+ray_run_name: Optional[str] = field(
+default=None,
+metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
+)
+ray_num_workers: int = field(
+default=1,
+metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
+)
+resources_per_worker: Union[dict, str] = field(
+default_factory=lambda: {"GPU": 1},
+metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+)
+placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
+default="PACK",
+metadata={"help": "The placement strategy for Ray training. Default is PACK."},
+)
+def __post_init__(self):
+self.use_ray = use_ray()
+if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
+self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+@dataclass
+class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+r"""
+Arguments pertaining to the trainer.
+"""
+def __post_init__(self):
+Seq2SeqTrainingArguments.__post_init__(self)
+RayArguments.__post_init__(self)
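
Example of how `__post_init__` normalizes a CLI-style value (hypothetical values):

```
args = RayArguments(ray_run_name="sft_demo", ray_num_workers=4, resources_per_worker='{"GPU": 1}')
print(args.resources_per_worker)  # {'GPU': 1} after json.loads + _convert_str_dict
print(args.use_ray)               # mirrors the USE_RAY environment variable
```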

View File

@@ -15,9 +15,9 @@
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version
from ...extras import logging
+from ...extras.misc import check_version
if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"

View File

@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

View File

@@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
-from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
check_version("transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+from ...extras.misc import check_version
if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, leaf_modules)
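
Typical use (assumed caller; `MixtralSparseMoeBlock` is one class commonly kept whole under ZeRO-3 so its parameters are gathered as one unit):

```
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# `model` is assumed to be a loaded Mixtral PreTrainedModel.
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```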

View File

@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
-from transformers.utils.versions import require_version
from ...extras import logging
+from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
check_version("transformers>=4.43.0,<=4.46.1")
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")

View File

@@ -26,11 +26,10 @@ from datasets import load_dataset
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device
if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
check_version("auto_gptq>=0.5.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
check_version("autoawq", mandatory=True)
if quant_method == QuantizationMethod.AQLM:
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
check_version("aqlm>=1.1.0", mandatory=True)
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
@@ -136,8 +135,8 @@ def configure_quantization(
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
check_version("optimum>=1.17.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True)
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
check_version("bitsandbytes>=0.39.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
@@ -175,7 +174,7 @@ def configure_quantization(
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
check_version("bitsandbytes>=0.43.0", mandatory=True)
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
@@ -187,7 +186,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("hqq", "To fix: pip install hqq")
check_version("hqq", mandatory=True)
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("eetq", "To fix: pip install eetq")
check_version("eetq", mandatory=True)
init_kwargs["quantization_config"] = EetqConfig()
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")

View File

@@ -35,7 +35,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray
if is_safetensors_available():
@@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
self.do_train = False
# Web UI
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-if self.webui_mode:
+if self.webui_mode and not use_ray():
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.add_handler(self.logger_handler)
@@ -239,7 +239,7 @@ class LogCallback(TrainerCallback):
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning_once("Previous trainer log in this folder will be deleted.")
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
@@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
)
if self.finetuning_args.use_swanlab:
-import swanlab
+import swanlab  # type: ignore
swanlab.config.update(
{

View File

@@ -31,7 +31,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
-batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+batch = nested_detach(batch, clone=True)  # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])

View File

@@ -30,7 +30,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
r"""
Runs forward pass and computes the log probabilities.
"""
-batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+batch = nested_detach(batch, clone=True)  # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],

View File

@@ -122,7 +122,7 @@ def run_sft(
# Predict
if training_args.do_predict:
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)

View File

@@ -17,7 +17,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -30,20 +32,25 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
-from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
+if is_ray_available():
+from ray.train import RunConfig, ScalingConfig
+from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
-from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
+from transformers import PreTrainedModel, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
-from ..hparams import DataArguments
+from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -74,7 +81,7 @@ def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
@@ -187,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -271,7 +278,7 @@ def _create_galore_optimizer(
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
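`_create_loraplus_optimizer` implements the LoRA+ recipe: `lora_B` matrices train with a larger learning rate than `lora_A`. A minimal sketch of the parameter grouping, assuming PEFT's `lora_A`/`lora_B` naming and a hypothetical `ratio` argument:

    from typing import List

    import torch

    def loraplus_groups(model: "torch.nn.Module", default_lr: float, ratio: float) -> List[dict]:
        a_params = [p for n, p in model.named_parameters() if "lora_A" in n]
        b_params = [p for n, p in model.named_parameters() if "lora_B" in n]
        return [
            {"params": a_params, "lr": default_lr},          # A matrices: base LR
            {"params": b_params, "lr": default_lr * ratio},  # B matrices: scaled LR
        ]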
@@ -311,7 +318,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
@@ -330,7 +337,7 @@ def _create_badam_optimizer(
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
from badam import BlockOptimizer # type: ignore
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
@@ -350,7 +357,7 @@ def _create_badam_optimizer(
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio
from badam import BlockOptimizerRatio # type: ignore
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
@@ -372,9 +379,9 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini
from adam_mini import Adam_mini # type: ignore
hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -397,7 +404,7 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
@@ -414,7 +421,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
@@ -459,12 +466,33 @@ def get_batch_logps(
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
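For orientation, the return line above collapses per-token log-probabilities into one value per sequence, masking out prompt positions. A standalone sketch with illustrative shapes:

    import torch

    per_token_logps = torch.randn(2, 5)  # log p(token_t | prefix), shape (batch, seq)
    loss_mask = torch.tensor([[0., 1., 1., 1., 1.],
                              [0., 0., 1., 1., 1.]])  # 1 = response token, 0 = prompt
    seq_logps = (per_token_logps * loss_mask).sum(-1)  # summed response log-probs
    lengths = loss_mask.sum(-1)                        # response token counts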
def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detaches `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
elif isinstance(tensors, torch.Tensor):
if clone:
return tensors.detach().clone()
else:
return tensors.detach()
else:
return tensors
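A quick usage sketch of the helper above on a mixed container (illustrative values):

    import torch

    x = torch.ones(2, requires_grad=True) * 2  # tensor with grad history
    batch = {"logits": x, "extras": [x, (x,)], "task": "kto"}
    out = nested_detach(batch, clone=True)
    assert not out["logits"].requires_grad                 # detached
    assert out["extras"][1][0].data_ptr() != x.data_ptr()  # cloned to new storage
    assert out["task"] == "kto"                            # non-tensors pass through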
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
import swanlab
from swanlab.integration.transformers import SwanLabCallback
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
@@ -477,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
config={"Framework": "🦙LlamaFactory"},
)
return swanlab_callback
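For reference, the returned callback plugs straight into any Hugging Face `Trainer`; a hedged usage sketch in which `model` and `training_args` are hypothetical stand-ins:

    from transformers import Trainer

    swanlab_callback = get_swanlab_callback(finetuning_args)  # helper defined above
    trainer = Trainer(model=model, args=training_args, callbacks=[swanlab_callback])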
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_path=Path("./saves").absolute().as_posix(),
),
)
return trainer
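A hedged invocation sketch, mirroring how `run_exp` wires this up below; `args` and `ray_args` are assumed to come from the CLI parser:

    trainer = get_ray_trainer(
        training_function=_training_function,  # executed on every Ray worker
        train_loop_config={"args": args, "callbacks": []},
        ray_args=ray_args,
    )
    result = trainer.fit()  # blocks until all workers finish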

View File

@@ -22,7 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -31,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -41,10 +46,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
callbacks.append(LogCallback())
def _training_function(config: Dict[str, Any]) -> None:
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback())
if finetuning_args.pissa_convert:
callbacks.append(PissaConvertCallback())
@@ -69,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
args = read_args(args)
ray_args = get_ray_args(args)
callbacks = callbacks or []
if ray_args.use_ray:
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=_training_function,
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
_training_function(config={"args": args, "callbacks": callbacks})
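One subtlety: in Ray mode, the `RayTrainReportCallback` appended above is what bridges Hugging Face Trainer logs back to Ray Train. A simplified, hypothetical stand-in for what it reports:

    import ray.train
    from transformers import TrainerCallback

    class ReportToRay(TrainerCallback):  # simplified stand-in, not the real callback
        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs:
                ray.train.report(metrics=logs)  # surface HF metrics to Ray Train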
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
@@ -394,12 +394,12 @@ class Runner:
continue
if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_path, "all_results.json")):
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]