mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-19 23:33:09 +00:00
Compare commits
11 Commits
9501c3308a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e67ab9e2f2 | ||
|
|
2c4f121817 | ||
|
|
487f8b8191 | ||
|
|
78cad1e332 | ||
|
|
70653026f5 | ||
|
|
246192abd2 | ||
|
|
0258dc14d0 | ||
|
|
3045adf0ba | ||
|
|
a3d44e3152 | ||
|
|
edeb953bc7 | ||
|
|
d045794387 |
@@ -473,7 +473,7 @@ huggingface-cli login
|
|||||||
|
|
||||||
| Mandatory | Minimum | Recommend |
|
| Mandatory | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.9 | 3.10 |
|
| python | 3.11 | >=3.11 |
|
||||||
| torch | 2.0.0 | 2.6.0 |
|
| torch | 2.0.0 | 2.6.0 |
|
||||||
| torchvision | 0.15.0 | 0.21.0 |
|
| torchvision | 0.15.0 | 0.21.0 |
|
||||||
| transformers | 4.49.0 | 4.50.0 |
|
| transformers | 4.49.0 | 4.50.0 |
|
||||||
|
|||||||
@@ -475,7 +475,7 @@ huggingface-cli login
|
|||||||
|
|
||||||
| 必需项 | 至少 | 推荐 |
|
| 必需项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.9 | 3.10 |
|
| python | 3.11 | >=3.11 |
|
||||||
| torch | 2.0.0 | 2.6.0 |
|
| torch | 2.0.0 | 2.6.0 |
|
||||||
| torchvision | 0.15.0 | 0.21.0 |
|
| torchvision | 0.15.0 | 0.21.0 |
|
||||||
| transformers | 4.49.0 | 4.50.0 |
|
| transformers | 4.49.0 | 4.50.0 |
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# https://hub.docker.com/r/ascendai/cann/tags
|
# https://hub.docker.com/r/ascendai/cann/tags
|
||||||
|
|
||||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
|
||||||
FROM ${BASE_IMAGE}
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
# Installation arguments
|
# Installation arguments
|
||||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install torch-npu
|
# Install torch-npu
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
RUN pip uninstall -y torch torchvision torchaudio
|
||||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||||
|
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||||
|
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ services:
|
|||||||
dockerfile: ./docker/docker-npu/Dockerfile
|
dockerfile: ./docker/docker-npu/Dockerfile
|
||||||
context: ../..
|
context: ../..
|
||||||
args:
|
args:
|
||||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory-a3
|
container_name: llamafactory-a3
|
||||||
image: llamafactory:npu-a3
|
image: llamafactory:npu-a3
|
||||||
|
|||||||
@@ -28,12 +28,7 @@ save_only_model: false
|
|||||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
### ray
|
### ray
|
||||||
ray_run_name: qwen3_4b_sft_lora
|
|
||||||
ray_storage_path: ./saves
|
|
||||||
ray_num_workers: 4 # Number of GPUs to use.
|
ray_num_workers: 4 # Number of GPUs to use.
|
||||||
placement_strategy: PACK
|
|
||||||
resources_per_worker:
|
|
||||||
GPU: 1
|
|
||||||
# ray_init_kwargs:
|
# ray_init_kwargs:
|
||||||
# runtime_env:
|
# runtime_env:
|
||||||
# env_vars:
|
# env_vars:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
torch==2.7.1
|
torch==2.7.1
|
||||||
torch-npu==2.7.1
|
torch-npu==2.7.1.post2
|
||||||
torchvision==0.22.1
|
torchvision==0.22.1
|
||||||
torchaudio==2.7.1
|
torchaudio==2.7.1
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def convert(
|
|||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
expert_model_parallel_size: int = 1,
|
expert_model_parallel_size: int = 1,
|
||||||
virtual_pipeline_model_parallel_size: int | None = None,
|
virtual_pipeline_model_parallel_size: int | None = None,
|
||||||
|
moe_grouped_gemm: bool | None = None,
|
||||||
):
|
):
|
||||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||||
|
|
||||||
@@ -84,6 +85,10 @@ def convert(
|
|||||||
pipeline_model_parallel_size: Pipeline model parallel size
|
pipeline_model_parallel_size: Pipeline model parallel size
|
||||||
expert_model_parallel_size: Expert model parallel size
|
expert_model_parallel_size: Expert model parallel size
|
||||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||||
|
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||||
|
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||||
|
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||||
|
Must match the format used when saving the checkpoint.
|
||||||
"""
|
"""
|
||||||
if bf16 and fp16:
|
if bf16 and fp16:
|
||||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||||
@@ -97,8 +102,9 @@ def convert(
|
|||||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||||
expert_model_parallel_size=expert_model_parallel_size,
|
expert_model_parallel_size=expert_model_parallel_size,
|
||||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||||
|
moe_grouped_gemm=moe_grouped_gemm,
|
||||||
|
transformer_impl="transformer_engine", # hard code here since we default using te for training
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_checkpoint_to_mca(
|
convert_checkpoint_to_mca(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
output_path,
|
output_path,
|
||||||
|
|||||||
@@ -88,7 +88,10 @@ def _process_request(
|
|||||||
|
|
||||||
if request.messages[0].role == Role.SYSTEM:
|
if request.messages[0].role == Role.SYSTEM:
|
||||||
content = request.messages.pop(0).content
|
content = request.messages.pop(0).content
|
||||||
system = content[0].text if isinstance(content, list) else content
|
if isinstance(content, list):
|
||||||
|
system = content[0].text if content else ""
|
||||||
|
else:
|
||||||
|
system = content
|
||||||
else:
|
else:
|
||||||
system = None
|
system = None
|
||||||
|
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
|
|||||||
|
|
||||||
# filter out non-JSON files
|
# filter out non-JSON files
|
||||||
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
||||||
files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
|
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
|
||||||
if not files:
|
if not files:
|
||||||
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
||||||
|
|
||||||
|
|||||||
@@ -161,7 +161,9 @@ class MMPluginMixin:
|
|||||||
video_processor: BaseImageProcessor = getattr(
|
video_processor: BaseImageProcessor = getattr(
|
||||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||||
)
|
)
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
if len(images) != 0 and self.image_token is None:
|
if len(images) != 0 and self.image_token is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This model does not support image input. Please check whether the correct `template` is used."
|
"This model does not support image input. Please check whether the correct `template` is used."
|
||||||
@@ -390,7 +392,9 @@ class MMPluginMixin:
|
|||||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||||
|
|
||||||
if len(audios) != 0:
|
if len(audios) != 0:
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
audios = self._regularize_audios(
|
audios = self._regularize_audios(
|
||||||
audios,
|
audios,
|
||||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||||
@@ -1054,7 +1058,9 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
chunk_input=True,
|
chunk_input=True,
|
||||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||||
)
|
)
|
||||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
audio_feature_lens = [
|
||||||
|
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||||
|
]
|
||||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||||
if kwargs.get("ret_phs", False):
|
if kwargs.get("ret_phs", False):
|
||||||
mm_inputs.update({"audio_phs": audio_phs})
|
mm_inputs.update({"audio_phs": audio_phs})
|
||||||
@@ -1094,7 +1100,7 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
@@ -1876,7 +1882,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
) -> dict[str, "torch.Tensor"]:
|
) -> dict[str, "torch.Tensor"]:
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
@@ -1981,6 +1989,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
|
||||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||||
video_t_index = (
|
video_t_index = (
|
||||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||||
@@ -1992,9 +2001,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
)
|
)
|
||||||
.flatten()
|
.flatten()
|
||||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
* position_id_per_seconds
|
||||||
).long()
|
).long()
|
||||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
t_ntoken_per_chunk = position_id_per_seconds * 2
|
||||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||||
placeholder_string = ""
|
placeholder_string = ""
|
||||||
|
|||||||
@@ -1113,7 +1113,7 @@ register_template(
|
|||||||
register_template(
|
register_template(
|
||||||
name="gpt_oss",
|
name="gpt_oss",
|
||||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ MCA_SUPPORTED_MODELS = {
|
|||||||
"qwen3",
|
"qwen3",
|
||||||
"qwen3_moe",
|
"qwen3_moe",
|
||||||
"qwen3_next",
|
"qwen3_next",
|
||||||
|
"qwen3_5",
|
||||||
|
"qwen3_5_moe",
|
||||||
}
|
}
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora", "oft"]
|
METHODS = ["full", "freeze", "lora", "oft"]
|
||||||
|
|||||||
@@ -470,7 +470,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
training_args.resume_from_checkpoint is None
|
training_args.resume_from_checkpoint is None
|
||||||
and training_args.do_train
|
and training_args.do_train
|
||||||
and os.path.isdir(training_args.output_dir)
|
and os.path.isdir(training_args.output_dir)
|
||||||
and not training_args.overwrite_output_dir
|
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
|
||||||
and can_resume_from_checkpoint
|
and can_resume_from_checkpoint
|
||||||
):
|
):
|
||||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
|
|||||||
if (
|
if (
|
||||||
args.should_save
|
args.should_save
|
||||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
and args.overwrite_output_dir
|
and getattr(args, "overwrite_output_dir", False)
|
||||||
):
|
):
|
||||||
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
||||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
@@ -77,20 +79,25 @@ def _data_collator_wrapper(data_collator: Any):
|
|||||||
|
|
||||||
def _check_model_support(model_args: "ModelArguments"):
|
def _check_model_support(model_args: "ModelArguments"):
|
||||||
from transformers import AutoConfig as HfAutoConfig
|
from transformers import AutoConfig as HfAutoConfig
|
||||||
|
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
|
||||||
|
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
|
||||||
|
model_type = mca_config.get("hf_model_type", None)
|
||||||
|
else:
|
||||||
|
config = HfAutoConfig.from_pretrained(
|
||||||
|
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||||
|
)
|
||||||
|
model_type = config.model_type
|
||||||
|
|
||||||
config = HfAutoConfig.from_pretrained(
|
if model_type not in MCA_SUPPORTED_MODELS:
|
||||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
|
||||||
)
|
|
||||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model {config.model_type} is not supported by mcore_adapter."
|
f"Model {model_type} is not supported by mcore_adapter."
|
||||||
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||||
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
|
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
params_to_freeze = []
|
params_to_freeze = []
|
||||||
|
|||||||
@@ -91,7 +91,11 @@ class Renderer:
|
|||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
|
||||||
def render_messages(
|
def render_messages(
|
||||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
"""Apply template to messages and convert them to model input.
|
"""Apply template to messages and convert them to model input.
|
||||||
|
|
||||||
@@ -99,6 +103,7 @@ class Renderer:
|
|||||||
messages (list[Message]): The messages to render.
|
messages (list[Message]): The messages to render.
|
||||||
tools (str | None, optional): The tools to use. Defaults to None.
|
tools (str | None, optional): The tools to use. Defaults to None.
|
||||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||||
|
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelInput: The rendered model input.
|
ModelInput: The rendered model input.
|
||||||
@@ -108,7 +113,9 @@ class Renderer:
|
|||||||
else:
|
else:
|
||||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||||
|
|
||||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
return RenderingPlugin(self.template).render_messages(
|
||||||
|
self.processor, messages, tools, is_generate, enable_thinking
|
||||||
|
)
|
||||||
|
|
||||||
def parse_message(self, generated_text: str) -> Message:
|
def parse_message(self, generated_text: str) -> Message:
|
||||||
"""Parse a message in the template format.
|
"""Parse a message in the template format.
|
||||||
|
|||||||
@@ -12,224 +12,45 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import importlib
|
||||||
import re
|
|
||||||
|
|
||||||
from ...utils.constants import IGNORE_INDEX
|
from ...utils import logging
|
||||||
from ...utils.helper import get_tokenizer
|
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
from ...utils.types import Message, ModelInput, Processor
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RenderingPlugin(BasePlugin):
|
class RenderingPlugin(BasePlugin):
|
||||||
|
_attempted_template_imports: set[str] = set()
|
||||||
|
|
||||||
|
def _ensure_template_imported(self) -> None:
|
||||||
|
if self.name is None or self.name in self._attempted_template_imports:
|
||||||
|
return
|
||||||
|
|
||||||
|
full_module_name = f"{__package__}.templates.{self.name}"
|
||||||
|
self._attempted_template_imports.add(self.name)
|
||||||
|
try:
|
||||||
|
importlib.import_module(full_module_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
|
||||||
|
|
||||||
|
def __getitem__(self, method_name: str):
|
||||||
|
self._ensure_template_imported()
|
||||||
|
return super().__getitem__(method_name)
|
||||||
|
|
||||||
def render_messages(
|
def render_messages(
|
||||||
self,
|
self,
|
||||||
processor: Processor,
|
processor: Processor,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: str | None = None,
|
tools: str | None = None,
|
||||||
is_generate: bool = False,
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
"""Render messages in the template format."""
|
"""Render messages in the template format."""
|
||||||
return self["render_messages"](processor, messages, tools, is_generate)
|
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
|
||||||
|
|
||||||
def parse_messages(self, generated_text: str) -> Message:
|
def parse_messages(self, generated_text: str) -> Message:
|
||||||
"""Parse messages in the template format."""
|
"""Parse messages in the template format."""
|
||||||
return self["parse_messages"](generated_text)
|
return self["parse_messages"](generated_text)
|
||||||
|
|
||||||
|
|
||||||
def _update_model_input(
|
|
||||||
processor: Processor,
|
|
||||||
input_ids: list[int],
|
|
||||||
labels: list[int],
|
|
||||||
loss_weights: list[int],
|
|
||||||
temp_str: str,
|
|
||||||
temp_weight: float,
|
|
||||||
) -> str:
|
|
||||||
"""Update model input with temporary string."""
|
|
||||||
if not temp_str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(processor)
|
|
||||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
|
||||||
input_ids.extend(temp_ids)
|
|
||||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
|
||||||
if temp_weight > 1e-6:
|
|
||||||
labels.extend(temp_ids)
|
|
||||||
else:
|
|
||||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
|
||||||
def render_qwen3_nothink_messages(
|
|
||||||
processor: Processor,
|
|
||||||
messages: list[Message],
|
|
||||||
tools: str | None = None,
|
|
||||||
is_generate: bool = False,
|
|
||||||
) -> ModelInput:
|
|
||||||
"""Render messages in the Qwen3 nothink template format.
|
|
||||||
|
|
||||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
|
||||||
"""
|
|
||||||
input_ids, labels, loss_weights = [], [], []
|
|
||||||
temp_str, temp_weight = "", 0.0
|
|
||||||
if tools:
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
if messages[0]["role"] == "system":
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n\n"
|
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
|
||||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
tools = json.loads(tools)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
|
||||||
|
|
||||||
if not isinstance(tools, list):
|
|
||||||
tools = [tools]
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
|
||||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
|
||||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
|
||||||
)
|
|
||||||
elif messages[0]["role"] == "system":
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
for turn_idx, message in enumerate(messages):
|
|
||||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
||||||
for content in message["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
elif message["role"] == "assistant":
|
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
||||||
for val_idx, content in enumerate(message["content"]):
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
elif content["type"] == "reasoning":
|
|
||||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
|
||||||
elif content["type"] == "tool_call":
|
|
||||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
|
||||||
temp_str += "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
tool_call: ToolCall = json.loads(content["value"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
'<tool_call>\n{"name": "'
|
|
||||||
+ tool_call["name"]
|
|
||||||
+ '", "arguments": '
|
|
||||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
|
||||||
+ "}\n</tool_call>"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 1.0)
|
|
||||||
elif message["role"] == "tool":
|
|
||||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_start|>user"
|
|
||||||
|
|
||||||
temp_str += "\n<tool_response>\n"
|
|
||||||
for content in message["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n</tool_response>"
|
|
||||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
if is_generate:
|
|
||||||
temp_str += "<|im_start|>assistant\n"
|
|
||||||
temp_weight = 0.0
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
attention_mask = [1] * len(input_ids)
|
|
||||||
return ModelInput(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
labels=labels,
|
|
||||||
loss_weights=loss_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
|
||||||
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
|
||||||
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
generated_text (str): The generated text in the Qwen3 nothink template format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Message: The parsed message.
|
|
||||||
"""
|
|
||||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
|
||||||
content = []
|
|
||||||
last_end = 0
|
|
||||||
for match in pattern.finditer(generated_text):
|
|
||||||
start, end = match.span()
|
|
||||||
if start > last_end:
|
|
||||||
text = generated_text[last_end:start].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
tag_type = match.group(1)
|
|
||||||
tag_value = match.group(2).strip()
|
|
||||||
if tag_type == "thinking":
|
|
||||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
|
||||||
elif tag_type == "tool_call":
|
|
||||||
try:
|
|
||||||
json.loads(tag_value.strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
|
||||||
|
|
||||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
|
||||||
|
|
||||||
last_end = end
|
|
||||||
|
|
||||||
if last_end < len(generated_text):
|
|
||||||
text = generated_text[last_end:].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
return Message(role="assistant", content=content)
|
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ....utils.constants import IGNORE_INDEX
|
||||||
|
from ....utils.helper import get_tokenizer
|
||||||
|
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
from ..rendering import RenderingPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[float],  # annotation fixed: extended with float temp_weight below
    temp_str: str,
    temp_weight: float,
) -> str:
    """Tokenize ``temp_str`` and flush it into the running model input, in place.

    ``input_ids``, ``labels`` and ``loss_weights`` are extended as a side effect.
    Always returns ``""`` so callers can reset their pending buffer with
    ``temp_str = _update_model_input(...)``.

    Args:
        processor: Processor whose tokenizer encodes ``temp_str``.
        input_ids: Running token-id list, extended in place.
        labels: Running label list, extended in place.
        loss_weights: Running per-token loss-weight list, extended in place.
        temp_str: Pending text to tokenize; no-op when empty.
        temp_weight: Loss weight applied to every token of ``temp_str``.

    Returns:
        str: Always the empty string (the flushed buffer).
    """
    if not temp_str:  # nothing pending to flush
        return ""

    tokenizer = get_tokenizer(processor)
    temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
    input_ids.extend(temp_ids)
    loss_weights.extend([temp_weight] * len(temp_ids))
    # Tokens with (near-)zero weight are masked out of the loss via IGNORE_INDEX.
    if temp_weight > 1e-6:
        labels.extend(temp_ids)
    else:
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_text_content(message: Message) -> str:
    """Join every text part of ``message`` into a single string.

    Raises:
        ValueError: If the message contains any non-text content part.
    """
    parts: list[str] = []
    for part in message["content"]:
        if part["type"] != "text":
            raise ValueError(f"Unsupported content type: {part['type']}")

        parts.append(part["value"])

    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_last_query_index(messages: list[Message]) -> int:
    """Return the index of the last genuine user query.

    Scans the conversation backwards and ignores user turns whose entire text
    is wrapped in ``<tool_response>...</tool_response>`` (tool results relayed
    through the user role) as well as user turns containing non-text parts.
    Falls back to the final message index when no such turn is found.
    """
    fallback = len(messages) - 1
    for idx in reversed(range(len(messages))):
        candidate = messages[idx]
        if candidate["role"] != "user":
            continue

        # Non-text content disqualifies the turn; keep scanning earlier ones.
        if any(part["type"] != "text" for part in candidate["content"]):
            continue

        full_text = "".join(part["value"] for part in candidate["content"])
        if full_text.startswith("<tool_response>") and full_text.endswith("</tool_response>"):
            continue  # wrapped tool response, not a real query

        return idx

    return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
    """Separate an assistant message into text, reasoning, and tool calls.

    Text and reasoning parts are concatenated in order; tool-call parts are
    JSON-decoded into ``ToolCall`` objects.

    Raises:
        ValueError: If a tool call is not valid JSON or a part has an
            unsupported content type.
    """
    texts: list[str] = []
    reasonings: list[str] = []
    tool_calls: list[ToolCall] = []

    for part in message["content"]:
        kind = part["type"]
        if kind == "text":
            texts.append(part["value"])
        elif kind == "reasoning":
            reasonings.append(part["value"])
        elif kind == "tool_call":
            try:
                parsed: ToolCall = json.loads(part["value"])
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {part['value']}.")

            tool_calls.append(parsed)
        else:
            raise ValueError(f"Unsupported content type: {kind}")

    return "".join(texts), "".join(reasonings), tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
    enable_thinking: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B

    Args:
        processor: Processor providing the tokenizer used for encoding.
        messages: Conversation turns to render.
        tools: Optional JSON string describing the available tools.
        is_generate: When True, append the assistant generation prompt.
        enable_thinking: When False at generation time, pre-fill an empty
            ``<think>`` block so the model skips reasoning.

    Returns:
        ModelInput: Token ids, attention mask, labels and per-token loss weights.

    Raises:
        ValueError: If ``tools`` is not valid JSON (tool-call JSON errors are
            raised by ``_split_assistant_content``).
    """
    input_ids, labels, loss_weights = [], [], []
    temp_str, temp_weight = "", 0.0
    # System turn: when tools are given, the tool catalog is merged into the
    # system prompt, after any user-provided system text.
    if tools:
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            temp_str += _concat_text_content(messages[0]) + "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        if not isinstance(tools, list):
            tools = [tools]  # allow a single tool object

        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
    # Reasoning (<think>) blocks are only rendered for turns after the last real user query.
    last_query_index = _get_last_query_index(messages)

    for turn_idx, message in enumerate(messages):
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"

            text_content, reasoning_content, tool_calls = _split_assistant_content(message)
            # Emit the <think> block only for the final assistant turn, or any
            # post-query turn that actually carries reasoning content.
            if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
                temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
            else:
                temp_str += text_content

            for tool_call_idx, tool_call in enumerate(tool_calls):
                # Newline before every tool call except a first one with no preceding text.
                if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
                    temp_str += "\n"

                arguments = tool_call.get("arguments")
                # Arguments may already be a JSON string; otherwise serialize them.
                if isinstance(arguments, str):
                    arguments_str = arguments
                else:
                    arguments_str = json.dumps(arguments, ensure_ascii=False)

                temp_str += (
                    '<tool_call>\n{"name": "'
                    + tool_call["name"]
                    + '", "arguments": '
                    + arguments_str
                    + "}\n</tool_call>"
                )

            temp_str += "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 1.0)  # assistant tokens are trained by default
        elif message["role"] == "tool":
            # Consecutive tool results are packed into one user turn, each
            # wrapped in <tool_response> tags.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        # Flush the rendered turn into the token buffers.
        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0
        if enable_thinking is False:
            # Pre-fill an empty think block to disable reasoning at inference.
            temp_str += "<think>\n\n</think>\n\n"

        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
    """Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.

    Args:
        generated_text (str): The generated text in the Qwen3 template format.

    Returns:
        Message: The parsed message.

    Raises:
        ValueError: If a ``<tool_call>`` block does not contain valid JSON.
    """
    tag_pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    parts = []
    cursor = 0  # end of the previously consumed region

    def _append_text(segment: str) -> None:
        # Plain text outside any tag; dropped when empty after stripping.
        stripped = segment.strip()
        if stripped:
            parts.append({"type": "text", "value": stripped})

    for tag_match in tag_pattern.finditer(generated_text):
        begin, finish = tag_match.span()
        if begin > cursor:
            _append_text(generated_text[cursor:begin])

        tag_name, tag_body = tag_match.group(1), tag_match.group(2).strip()
        if tag_name == "think":
            parts.append({"type": "reasoning", "value": tag_body})
        else:  # the only other alternation is "tool_call"
            try:
                json.loads(tag_body)  # validate; store the raw JSON string
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {tag_body}.")

            parts.append({"type": "tool_call", "value": tag_body})

        cursor = finish

    _append_text(generated_text[cursor:])

    return Message(role="assistant", content=parts)
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ....utils.constants import IGNORE_INDEX
|
||||||
|
from ....utils.helper import get_tokenizer
|
||||||
|
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
from ..rendering import RenderingPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[float],  # annotation fixed: extended with float temp_weight below
    temp_str: str,
    temp_weight: float,
) -> str:
    """Tokenize ``temp_str`` and flush it into the running model input, in place.

    ``input_ids``, ``labels`` and ``loss_weights`` are extended as a side effect.
    Always returns ``""`` so callers can reset their pending buffer with
    ``temp_str = _update_model_input(...)``.

    Args:
        processor: Processor whose tokenizer encodes ``temp_str``.
        input_ids: Running token-id list, extended in place.
        labels: Running label list, extended in place.
        loss_weights: Running per-token loss-weight list, extended in place.
        temp_str: Pending text to tokenize; no-op when empty.
        temp_weight: Loss weight applied to every token of ``temp_str``.

    Returns:
        str: Always the empty string (the flushed buffer).
    """
    if not temp_str:  # nothing pending to flush
        return ""

    tokenizer = get_tokenizer(processor)
    temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
    input_ids.extend(temp_ids)
    loss_weights.extend([temp_weight] * len(temp_ids))
    # Tokens with (near-)zero weight are masked out of the loss via IGNORE_INDEX.
    if temp_weight > 1e-6:
        labels.extend(temp_ids)
    else:
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_text_content(message: Message) -> str:
    """Join every text part of ``message`` into a single string.

    Raises:
        ValueError: If the message contains any non-text content part.
    """
    parts: list[str] = []
    for part in message["content"]:
        if part["type"] != "text":
            raise ValueError(f"Unsupported content type: {part['type']}")

        parts.append(part["value"])

    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
    enable_thinking: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 nothink template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507

    Args:
        processor: Processor providing the tokenizer used for encoding.
        messages: Conversation turns to render.
        tools: Optional JSON string describing the available tools.
        is_generate: When True, append the assistant generation prompt.
        enable_thinking: Must stay False; this template has no thinking mode.

    Returns:
        ModelInput: Token ids, attention mask, labels and per-token loss weights.

    Raises:
        ValueError: If ``tools`` or a tool call is not valid JSON, if a content
            part has an unsupported type, or if ``enable_thinking`` is True at
            generation time.
    """
    input_ids, labels, loss_weights = [], [], []
    temp_str, temp_weight = "", 0.0
    # System turn: when tools are given, the tool catalog is merged into the
    # system prompt, after any user-provided system text.
    if tools:
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            temp_str += _concat_text_content(messages[0]) + "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )

        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        if not isinstance(tools, list):
            tools = [tools]  # allow a single tool object

        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    for turn_idx, message in enumerate(messages):
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for val_idx, content in enumerate(message["content"]):
                if content["type"] == "text":
                    temp_str += content["value"]
                elif content["type"] == "reasoning":
                    temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n"  # avoid using special tokens
                elif content["type"] == "tool_call":
                    # Newline between a tool call and a preceding text/tool_call part.
                    if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
                        temp_str += "\n"

                    try:
                        tool_call: ToolCall = json.loads(content["value"])
                    except json.JSONDecodeError:
                        raise ValueError(f"Invalid tool call format: {content['value']}.")

                    temp_str += (
                        '<tool_call>\n{"name": "'
                        + tool_call["name"]
                        + '", "arguments": '
                        + json.dumps(tool_call["arguments"], ensure_ascii=False)
                        + "}\n</tool_call>"
                    )

                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 1.0)  # assistant tokens are trained by default
        elif message["role"] == "tool":
            # Consecutive tool results are packed into one user turn, each
            # wrapped in <tool_response> tags.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        # Flush the rendered turn into the token buffers.
        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0
        if enable_thinking:
            raise ValueError("The qwen3_nothink template does not support thinking mode.")

        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
    """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.

    Args:
        generated_text (str): The generated text in the Qwen3 nothink template format.

    Returns:
        Message: The parsed message.

    Raises:
        ValueError: If a ``<tool_call>`` block does not contain valid JSON.
    """
    tag_pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    parts = []
    cursor = 0  # end of the previously consumed region

    def _append_text(segment: str) -> None:
        # Plain text outside any tag; dropped when empty after stripping.
        stripped = segment.strip()
        if stripped:
            parts.append({"type": "text", "value": stripped})

    for tag_match in tag_pattern.finditer(generated_text):
        begin, finish = tag_match.span()
        if begin > cursor:
            _append_text(generated_text[cursor:begin])

        tag_name, tag_body = tag_match.group(1), tag_match.group(2).strip()
        if tag_name == "thinking":
            parts.append({"type": "reasoning", "value": tag_body})
        else:  # the only other alternation is "tool_call"
            try:
                json.loads(tag_body)  # validate; store the raw JSON string
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {tag_body}.")

            parts.append({"type": "tool_call", "value": tag_body})

        cursor = finish

    _append_text(generated_text[cursor:])

    return Message(role="assistant", content=parts)
|
||||||
@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
class Content(TypedDict):
    """A single typed content part of a chat message."""

    type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
    """Type of the content."""
    value: str
    """Value of the content."""
|
||||||
|
|||||||
Reference in New Issue
Block a user