[model] add mistral small models (#6786)

Former-commit-id: e5e95c39bc4199fa89c67e34f9adaaa987058744
hoshi-hiyouga
2025-02-01 04:31:38 +08:00
committed by GitHub
parent 800de98dc8
commit a28261a866
10 changed files with 106 additions and 32 deletions


@@ -20,7 +20,7 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0
+    transformers>=4.41.2,<=4.48.2,!=4.46.*,!=4.47.*,!=4.48.0
     datasets>=2.16.0,<=3.2.0
     accelerate>=0.34.0,<=1.2.1
     peft>=0.11.1,<=0.12.0
@@ -30,7 +30,7 @@ Dependency graph:
   longlora:
     transformers>=4.41.2,<4.48.0
   packing:
-    transformers>=4.43.0,<=4.48.1
+    transformers>=4.43.0,<=4.48.2

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
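
Both switches above are plain environment flags. A minimal sketch of the usual truthy-string pattern for reading them (the helper name is illustrative, not the project's API):

import os

def is_env_enabled(name: str, default: str = "0") -> bool:
    # treat "1", "true" or "y" (any casing) as enabled -- an assumed convention
    return os.getenv(name, default).strip().lower() in ("true", "y", "1")

if not is_env_enabled("DISABLE_VERSION_CHECK"):
    check_dependencies()  # runs the pinned check_version() calls listed above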


@@ -183,8 +183,8 @@ class HuggingfaceEngine(BaseEngine):
         if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
             gen_kwargs["input_ids"] = inputs
-            del gen_kwargs["image_sizes"]
             gen_kwargs["tokenizer"] = tokenizer

+        gen_kwargs.pop("image_sizes", None)
         return gen_kwargs, prompt_length
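
The change above swaps `del` for `dict.pop`: `del gen_kwargs["image_sizes"]` raises KeyError whenever the key is absent, so it could only run inside the minicpm branch, while `pop` with a default is a safe no-op and can run unconditionally for every model type. A tiny sketch:

gen_kwargs = {"input_ids": [1, 2, 3]}  # text-only models carry no "image_sizes"
# del gen_kwargs["image_sizes"]        # old form: raises KeyError here
gen_kwargs.pop("image_sizes", None)    # new form: no-op, returns None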


@@ -319,7 +319,7 @@ class LlavaNextPlugin(BasePlugin):
         if self.expand_mm_tokens:
             orig_height, orig_width = next(image_sizes)
             image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-            if getattr(processor, "vision_feature_select_strategy") == "default":
+            if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                 image_seqlen -= 1
         else:
             image_seqlen = 1
@@ -370,7 +370,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         if self.expand_mm_tokens:
             orig_height, orig_width = next(image_sizes)
             image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-            if getattr(processor, "vision_feature_select_strategy") == "default":
+            if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                 image_seqlen -= 1
         else:
             image_seqlen = 1
@@ -915,7 +915,7 @@ class VideoLlavaPlugin(BasePlugin):
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
+            if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                 image_seqlen -= 1
         else:
             image_seqlen, video_seqlen = 1, 1
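
All three hunks in this file apply the same guard: the two-argument form getattr(obj, name) raises AttributeError when a processor class lacks the attribute, while the three-argument form falls back to a default value. A tiny sketch:

class DummyProcessor:  # stands in for a processor without the attribute
    pass

proc = DummyProcessor()
# getattr(proc, "vision_feature_select_strategy")  # raises AttributeError
strategy = getattr(proc, "vision_feature_select_strategy", "default")  # -> "default"
if strategy == "default":
    ...  # image_seqlen -= 1, as in the plugins above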


@@ -220,6 +220,7 @@ def _register_template(
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
+    fuse_system_into_user: bool = False,
 ) -> None:
     r"""
     Registers a chat template.
@@ -242,7 +243,7 @@ def _register_template(
     )
     ```
     """
-    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
+    template_class = Llama2Template if fuse_system_into_user else Template
     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=default_slots)
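
The explicit fuse_system_into_user flag replaces a substring heuristic on the template name, which new templates would trip over: mistral_small (registered below) has "mistral" in its name yet must keep a standalone system block. A minimal sketch of what the flag controls, heavily simplified from the real Template/Llama2Template classes:

def render_first_turn(system: str, user: str, fuse_system_into_user: bool) -> str:
    if fuse_system_into_user:
        # llama2/mistral convention: fold the system text into the first user turn
        return f"[INST] {system}\n\n{user} [/INST]"
    return f"{system}[INST] {user} [/INST]"  # system stays a standalone segment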
@@ -751,6 +752,7 @@ _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+    fuse_system_into_user=True,
 )
@@ -760,6 +762,7 @@ _register_template(
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    fuse_system_into_user=True,
 )
@@ -878,11 +881,12 @@ _register_template(
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+    fuse_system_into_user=True,
 )
@@ -932,11 +936,12 @@ _register_template(
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+    fuse_system_into_user=True,
 )
@@ -978,15 +983,42 @@ _register_template(
 )
+
+
+# mistral tokenizer v3 tekken
+_register_template(
+    name="ministral",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    fuse_system_into_user=True,
+)
+
+
+# mistral tokenizer v3
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
-    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    fuse_system_into_user=True,
 )
+
+
+# mistral tokenizer v7 tekken (copied from ministral)
+_register_template(
+    name="mistral_small",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
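
For comparison, hand-derived single-exchange renders of the three templates above; the strings follow from the formatters and are illustrative, not captured tokenizer output:

ministral (v3 tekken, fused system):   <s>[INST]{system}\n\n{user}[/INST]{assistant}</s>
mistral (v3, fused system):            <s>[INST] {system}\n\n{user}[/INST] {assistant}</s>
mistral_small (v7 tekken, no fusion):  <s>[SYSTEM_PROMPT]{system}[/SYSTEM_PROMPT][INST]{user}[/INST]{assistant}</s>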
@@ -1081,12 +1113,17 @@ _register_template(
 )
+# copied from ministral template
 _register_template(
     name="pixtral",
     format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+    fuse_system_into_user=True,
 )


@@ -1201,23 +1201,44 @@ register_model_group(
 register_model_group(
     models={
-        "MiniCPM-o-2_6-Chat": {
+        "MiniCPM-o-2_6": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
         },
     },
     template="minicpm_v",
     vision=True,
 )


 register_model_group(
     models={
-        "MiniCPM-V-2_6-Chat": {
+        "MiniCPM-V-2_6": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
         },
     },
     template="minicpm_v",
     vision=True,
 )
+
+
+register_model_group(
+    models={
+        "Ministral-8B-Instruct-2410": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-8B-Instruct-2410",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-8B-Instruct-2410",
+        },
+        "Mistral-Nemo-Base-2407": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Base-2407",
+            DownloadSource.MODELSCOPE: "LLM-Research/Mistral-Nemo-Base-2407",
+        },
+        "Mistral-Nemo-Instruct-2407": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
+        },
+    },
+    template="ministral",
+)
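
A hedged sketch of how a registered group is consumed downstream: a display name resolves to a hub path for the chosen download source, and the group's template tag rides along (the dictionary and function names here are illustrative, not the project's exact internals):

SUPPORTED_MODELS = {}  # display name -> {source: hub path}, filled by register_model_group
DEFAULT_TEMPLATE = {}  # display name -> template name

def resolve_model(name: str, use_modelscope: bool = False) -> tuple[str, str]:
    sources = SUPPORTED_MODELS[name]
    source = "modelscope" if use_modelscope and "modelscope" in sources else "default"
    return sources[source], DEFAULT_TEMPLATE.get(name, "")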
@@ -1227,48 +1248,60 @@ register_model_group(
         "Mistral-7B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
         },
-        "Mistral-7B-Instruct-v0.1": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
-        },
         "Mistral-7B-v0.2": {
             DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
         },
+        "Mistral-7B-v0.3": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
+            DownloadSource.MODELSCOPE: "LLM-Research/mistral-7b-v0.3",
+        },
+        "Mistral-7B-Instruct-v0.1": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
+        },
         "Mistral-7B-Instruct-v0.2": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
         },
-        "Mistral-7B-v0.3": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
-        },
         "Mistral-7B-Instruct-v0.3": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
             DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
         },
-        "Mistral-Nemo-Instruct-2407": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
-        },
     },
     template="mistral",
 )
+
+
+register_model_group(
+    models={
+        "Mistral-Small-24B-Base-2501": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2501",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2501",
+        },
+        "Mistral-Small-24B-Instruct-2501": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2501",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2501",
+        },
+    },
+    template="mistral_small",
+)


 register_model_group(
     models={
         "Mixtral-8x7B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
         },
-        "Mixtral-8x7B-v0.1-Instruct": {
-            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
-        },
         "Mixtral-8x22B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
         },
+        "Mixtral-8x7B-v0.1-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
+        },
         "Mixtral-8x22B-v0.1-Instruct": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
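
With the groups above registered, the new checkpoints resolve by display name; continuing the hypothetical resolve_model sketch from earlier:

path, template = resolve_model("Mistral-Small-24B-Instruct-2501")
# -> ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small")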


@@ -94,7 +94,7 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    check_version("transformers>=4.41.2,<=4.48.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.41.2,<=4.48.2,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.2.0")
     check_version("accelerate>=0.34.0,<=1.2.1")
     check_version("peft>=0.11.1,<=0.12.0")
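
A minimal sketch of what a requirement-string check like those above can look like, built on the packaging library (the project's real helper may differ):

import importlib.metadata

from packaging.requirements import Requirement

def check_version(requirement: str) -> None:
    req = Requirement(requirement)  # parses "transformers>=4.41.2,<=4.48.2,!=4.48.0"
    installed = importlib.metadata.version(req.name)
    if not req.specifier.contains(installed, prereleases=True):
        raise RuntimeError(f"{req.name}=={installed} does not satisfy {requirement!r}")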


@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    check_version("transformers>=4.43.0,<=4.48.1")
+    check_version("transformers>=4.43.0,<=4.48.2")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
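
The monkey-patch above is what makes packing block-diagonal: assuming the attention mask labels the k-th packed sample in each row with value k (e.g. [1, 1, 2, 2, 2, 0], where 0 is padding), returning per-sample lengths lets flash-attention treat every sample independently instead of as one long causal block. A simplified sketch of such a helper under that mask convention (the project's actual implementation may differ):

import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len); 0 = padding, k = k-th packed sample in the row
    counts = [(attention_mask == k).sum(-1) for k in range(1, int(attention_mask.max()) + 1)]
    seqlens_in_batch = torch.stack(counts, dim=-1).flatten()
    seqlens_in_batch = seqlens_in_batch[seqlens_in_batch > 0].to(torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch  # the triple flash-attn expects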