mirror of https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00

Compare commits
4 Commits: 9640f79ae5 ... main

| Author | SHA1 | Date |
|---|---|---|
| | e7cb145f5d | |
| | b53d7037c2 | |
| | bf04ca6af8 | |
| | 762b480131 | |
@@ -18,7 +18,7 @@ init_config:
   name: init_on_meta

 ### data
 train_dataset: data/v1_sft_demo.yaml

 ### training
 output_dir: outputs/test_fsdp2
@@ -40,10 +40,10 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
+    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
-    "peft>=0.14.0,<=0.17.1",
+    "peft>=0.18.0,<=0.18.1",
     "trl>=0.18.0,<=0.24.0",
     "torchdata>=0.10.0,<=0.11.0",
     # gui

@@ -1 +1 @@
-deepspeed>=0.10.0,<=0.16.9
+deepspeed>=0.10.0,<=0.18.4
@@ -2159,6 +2159,40 @@ class LFMVLPlugin(BasePlugin):
         return messages


+@dataclass
+class YoutuVLPlugin(BasePlugin):
+    r"""Plugin for Youtu-VL vision-language models."""
+
+    vision_bos_token: str = "<|vision_start|>"
+    vision_eos_token: str = "<|vision_end|>"
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+
+        for message in messages:
+            content = message["content"]
+            content = content.replace(
+                IMAGE_PLACEHOLDER, f"{self.vision_bos_token}{self.image_token}{self.vision_eos_token}"
+            )
+            content = content.replace(
+                VIDEO_PLACEHOLDER, f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+            )
+
+            message["content"] = content
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,

@@ -2181,6 +2215,7 @@ PLUGINS = {
     "qwen2_vl": Qwen2VLPlugin,
     "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
+    "youtu_vl": YoutuVLPlugin,
 }
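For context on what the new plugin does to a prompt: it wraps each image/video placeholder in the vision start/end tokens before the template is tokenized. Below is a minimal standalone sketch of that substitution, not repo code; the placeholder literals are assumptions, and the pad tokens come from the youtu_vl template registered below.

# Hypothetical illustration of YoutuVLPlugin.process_messages' replacement step.
IMAGE_PLACEHOLDER = "<image>"  # assumed placeholder literal
VIDEO_PLACEHOLDER = "<video>"  # assumed placeholder literal
vision_bos, vision_eos = "<|vision_start|>", "<|vision_end|>"
image_token, video_token = "<|image_pad|>", "<|video_pad|>"

message = {"role": "user", "content": f"{IMAGE_PLACEHOLDER} What is in this picture?"}
content = message["content"]
content = content.replace(IMAGE_PLACEHOLDER, f"{vision_bos}{image_token}{vision_eos}")
content = content.replace(VIDEO_PLACEHOLDER, f"{vision_bos}{video_token}{vision_eos}")
message["content"] = content
print(message["content"])  # -> <|vision_start|><|image_pad|><|vision_end|> What is in this picture?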
@@ -2146,6 +2146,19 @@ register_template(
 )


+register_template(
+    name="youtu_vl",
+    format_user=StringFormatter(
+        slots=["<|begin_of_text|>user\n{{content}}<|end_of_text|>\n<|begin_of_text|>assistant\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+    format_system=StringFormatter(slots=["<|begin_of_text|>system\n{{content}}<|end_of_text|>\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|end_of_text|>"],
+    mm_plugin=get_mm_plugin(name="youtu_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
 register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
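Based only on the formatter slots registered above, a one-turn exchange with the default system prompt should render roughly as follows (a sketch of the expected string, assuming the usual behavior of emitting the system slot first, then the user slot, then the assistant content):

<|begin_of_text|>system
You are a helpful assistant.<|end_of_text|>
<|begin_of_text|>user
How are you?<|end_of_text|>
<|begin_of_text|>assistant
I am fine.<|end_of_text|>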
@@ -3375,6 +3375,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Youtu-VL-4B-Instruct": {
+            DownloadSource.DEFAULT: "tencent/Youtu-VL-4B-Instruct",
+            DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-VL-4B-Instruct",
+        },
+    },
+    template="youtu_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Yuan2-2B-Chat": {
@@ -41,12 +41,13 @@ class LoggerHandler(logging.Handler):
             datefmt="%Y-%m-%d %H:%M:%S",
         )
         self.setLevel(logging.INFO)
+        self.thread_pool = ThreadPoolExecutor(max_workers=1)
         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
-        if os.path.exists(self.running_log):
+        try:
             os.remove(self.running_log)
-        self.thread_pool = ThreadPoolExecutor(max_workers=1)
+        except OSError:
+            pass

     def _write_log(self, log_entry: str) -> None:
         with open(self.running_log, "a", encoding="utf-8") as f:
@@ -94,10 +94,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=4.57.1")
+    check_version("transformers>=4.51.0,<=5.0.0")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.11.0")
-    check_version("peft>=0.14.0,<=0.17.1")
+    check_version("peft>=0.18.0,<=0.18.1")
     check_version("trl>=0.18.0,<=0.24.0")


@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
     return torch.device(device)


+def get_device_name() -> str:
+    r"""Get the name of available devices."""
+    if is_torch_xpu_available():
+        device = "xpu"
+    elif is_torch_npu_available():
+        device = "npu"
+    elif is_torch_mps_available():
+        device = "mps"
+    elif is_torch_cuda_available():
+        device = "gpu"
+    else:
+        device = "cpu"
+
+    return device
+
+
+def get_torch_device():
+    r"""Get the torch device namespace for the available devices."""
+    device_name = get_device_name()
+    device_name = "cuda" if device_name == "gpu" else device_name
+    try:
+        return getattr(torch, device_name)
+    except AttributeError:
+        logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
+        return torch.cuda
+
+
 def get_device_count() -> int:
     r"""Get the number of available devices."""
     if is_torch_xpu_available():
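The two new helpers are consumed later in this change set (the Ray worker setup and trainer utilities import them). A minimal usage sketch, assuming a CUDA machine and that the package is importable as llamafactory:

# Hypothetical usage of the helpers added above (not part of the diff).
from llamafactory.extras.misc import get_device_name, get_torch_device

name = get_device_name()      # "gpu" on CUDA, "npu"/"xpu"/"mps" on other backends, else "cpu"
backend = get_torch_device()  # e.g. torch.cuda when name == "gpu"
if name == "gpu":
    backend.set_device(0)     # select the first accelerator
    backend.empty_cache()     # free cached memory on that backend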
@@ -65,7 +65,9 @@ class DataArguments:
     )
     mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
+        metadata={
+            "help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."
+        },
     )
     interleave_probs: str | None = field(
         default=None,
@@ -206,9 +206,6 @@ class BaseModelArguments:
         if self.model_name_or_path is None:
             raise ValueError("Please provide `model_name_or_path`.")

-        if self.split_special_tokens and self.use_fast_tokenizer:
-            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

@@ -139,10 +139,6 @@ def _verify_model_args(
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-    if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
-        model_args.use_fast_tokenizer = False
-

 def _check_extra_dependencies(
     model_args: "ModelArguments",

@@ -188,9 +184,7 @@ def _check_extra_dependencies(

     if training_args is not None:
         if training_args.deepspeed:
-            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
             check_version("deepspeed", mandatory=True)
-            check_version("deepspeed>=0.10.0,<=0.16.9")

         if training_args.predict_with_generate:
             check_version("jieba", mandatory=True)
@@ -14,7 +14,6 @@

 import json
 from dataclasses import dataclass, field
-from typing import Literal

 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict

@@ -40,56 +39,29 @@ else:
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""

-    ray_run_name: str | None = field(
-        default=None,
-        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
-    )
-    ray_storage_path: str = field(
-        default="./saves",
-        metadata={"help": "The storage path to save training results to"},
-    )
-    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
-        default=None,
-        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
-    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: dict | str = field(
-        default_factory=lambda: {"GPU": 1},
-        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
-    )
-    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
-        default="PACK",
-        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
-    )
     ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
+    master_addr: str | None = field(
+        default=None,
+        metadata={"help": "The master address for init_process_group"},
+    )
+    master_port: str | None = field(
+        default=None,
+        metadata={"help": "The master port for init_process_group"},
+    )

     def __post_init__(self):
         self.use_ray = use_ray()
-        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
-            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
-
         if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
             self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
-
-        if self.ray_storage_filesystem is not None:
-            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
-                raise ValueError(
-                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
-                )
-
-            import pyarrow.fs as fs
-
-            if self.ray_storage_filesystem == "s3":
-                self.ray_storage_filesystem = fs.S3FileSystem()
-            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
-                self.ray_storage_filesystem = fs.GcsFileSystem()


 @dataclass
 class Fp8Arguments:
@@ -22,7 +22,6 @@ from transformers import (
     AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
-    AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
 )

@@ -166,11 +165,9 @@ def load_model(
         else:
             if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
-            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
-            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio-text for qwen omni
                 load_class = AutoModelForTextToWaveform
             else:
                 load_class = AutoModelForCausalLM
@@ -57,6 +57,11 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )

+    if getattr(config, "model_type", None) in ["youtu", "youtu_vl"]:
+        if model_args.flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
+            logger.warning_rank0("Youtu-VL does not support SDPA, forcing eager attention.")
+            model_args.flash_attn = AttentionFunction.DISABLED
+
     if model_args.flash_attn == AttentionFunction.AUTO:
         return

@@ -85,6 +90,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     elif getattr(config, "model_type", None) == "kimi_vl":
         setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
         setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "youtu_vl":
+        setattr(config, "attn_implementation", requested_attn_implementation)
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "vision_config"):
+            setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
@@ -374,7 +374,13 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
     projector_key="visual.merger",
-    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
+    vision_model_keys=[
+        "visual.pos_embed",
+        "visual.patch_embed",
+        "visual.blocks",
+        "visual.deepstack_merger_list",
+        "audio_tower",
+    ],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock


+def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
+    original_forward = model.forward
+
+    def forward(self, *args, **kwargs):
+        outputs = original_forward(*args, **kwargs)
+        if "loss" not in outputs and "labels" in kwargs:
+            logits = outputs.get("logits")
+            labels = kwargs.get("labels")
+            if logits is not None and labels is not None:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                outputs["loss"] = loss
+
+        return outputs
+
+    model.forward = MethodType(forward, model)
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
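The injected forward computes a standard next-token cross-entropy when the model itself returns no loss. A toy, self-contained illustration of that shift-and-flatten step (not repo code; in real label tensors the prompt positions would carry an ignore index):

import torch

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)           # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (1, 5))    # (batch, seq_len)

shift_logits = logits[..., :-1, :].contiguous()  # position t predicts token t+1
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())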
@@ -207,6 +227,9 @@ def patch_model(
     if getattr(model.config, "model_type", None) == "gemma3n":
         setattr(model_args, "disable_gradient_checkpointing", True)

+    if getattr(model.config, "model_type", None) == "youtu_vl":
+        patch_youtu_vl_model(model)
+
     prepare_model_for_training(model, model_args)
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)
@@ -103,7 +103,9 @@ class FixValueHeadModelCallback(TrainerCallback):
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
+                model=kwargs.pop("model"),
+                output_dir=output_dir,
+                safe_serialization=getattr(args, "save_safetensors", True),
             )

@@ -137,7 +139,7 @@ class PissaConvertCallback(TrainerCallback):
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_init_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

     @override

@@ -155,11 +157,11 @@ class PissaConvertCallback(TrainerCallback):
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
-                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+                model.save_pretrained(pissa_backup_dir, safe_serialization=getattr(args, "save_safetensors", True))
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                 model.save_pretrained(
                     pissa_convert_dir,
-                    safe_serialization=args.save_safetensors,
+                    safe_serialization=getattr(args, "save_safetensors", True),
                     path_initial_model_for_weight_conversion=pissa_init_dir,
                 )
                 model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
@@ -72,7 +72,7 @@ def run_ppo(
     ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
     ppo_trainer.save_model()
     if training_args.should_save:
-        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))

     ppo_trainer.save_state()  # must be called after save_model to have a folder
     if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:

@@ -114,7 +114,7 @@ class PairwiseTrainer(Trainer):
         if state_dict is None:
             state_dict = self.model.state_dict()

-        if self.args.save_safetensors:
+        if getattr(self.args, "save_safetensors", True):
            from collections import defaultdict

            ptrs = defaultdict(list)

@@ -65,7 +65,7 @@ def run_rm(
     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
     trainer.save_model()
     if training_args.should_save:
-        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        fix_valuehead_checkpoint(model, training_args.output_dir, getattr(training_args, "save_safetensors", True))

     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)
@@ -20,7 +20,6 @@
 import json
 import os
 from collections.abc import Callable, Mapping
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch

@@ -34,6 +33,7 @@ from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
 from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params

@@ -49,15 +49,15 @@ if is_apollo_available():

 if is_ray_available():
     import ray
-    from ray.train import RunConfig, ScalingConfig
-    from ray.train.torch import TorchTrainer
+    from ray.util.placement_group import PlacementGroup, placement_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TrainerCallback, TrainerState
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments, RayArguments, TrainingArguments
+    from ..hparams import DataArguments, TrainingArguments


 logger = logging.get_logger(__name__)
@@ -807,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
     return swanlab_callback


-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: dict[str, Any],
-    ray_args: "RayArguments",
-) -> "TorchTrainer":
-    if not ray_args.use_ray:
-        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-    if ray_args.ray_init_kwargs is not None:
-        ray.init(**ray_args.ray_init_kwargs)
-
-    if ray_args.ray_storage_filesystem is not None:
-        # this means we are using s3/gcs
-        storage_path = ray_args.ray_storage_path
-    else:
-        storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.ray_num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            placement_strategy=ray_args.placement_strategy,
-            use_gpu=True,
-        ),
-        run_config=RunConfig(
-            name=ray_args.ray_run_name,
-            storage_filesystem=ray_args.ray_storage_filesystem,
-            storage_path=storage_path,
-        ),
-    )
-    return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+    r"""Get the Ray placement group for distributed training."""
+    bundle = {"CPU": 10}
+    device_name = get_device_name().upper()
+    if device_name != "CPU":
+        bundle[device_name] = 1
+    bundles = [bundle for _ in range(num_workers)]
+    pg = placement_group(bundles, strategy="PACK")
+
+    return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+    placement_group: "PlacementGroup",
+    bundle_idx: int,
+    rank: int,
+    world_size: int,
+    master_addr: str,
+    master_port: str,
+    env: dict[str, str] = None,
+) -> dict[str, Any]:
+    r"""Get the remote config for a Ray worker."""
+    env_vars = {
+        "RANK": str(rank),
+        "WORLD_SIZE": str(world_size),
+        "MASTER_ADDR": master_addr,
+        "MASTER_PORT": master_port,
+        "TORCHELASTIC_USE_AGENT_STORE": "False",
+    }
+    env.update(env_vars)
+
+    remote_config = {
+        "scheduling_strategy": PlacementGroupSchedulingStrategy(
+            placement_group=placement_group,
+            placement_group_bundle_index=bundle_idx,
+        ),
+        "runtime_env": {"env_vars": env},
+        "num_cpus": 10,
+    }
+
+    device_name = get_device_name()
+    if device_name == "gpu":
+        remote_config["num_gpus"] = 1
+    elif device_name == "npu":
+        remote_config["resources"] = {"NPU": 1}
+
+    return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+    r"""Get the IP address of the Ray head node."""
+    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+    r"""Sort the placement group bundles by their node IP addresses."""
+
+    @ray.remote
+    def _get_node_ip():
+        return ray.util.get_node_ip_address().strip("[]")
+
+    tasks = []
+    for bundle_idx in range(placement_group.bundle_count):
+        task = _get_node_ip.options(
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_idx,
+            ),
+        ).remote()
+        tasks.append(task)
+
+    bundle_ips = ray.get(tasks)
+    bundle_node_ip_list = list(enumerate(bundle_ips))
+
+    sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+    sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+
+    if master_addr is not None:
+        preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+        if preferred_indices:
+            remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+            sorted_bundle_indices = preferred_indices + remaining
+
+    return sorted_bundle_indices
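The remote config returned above is intended to be splatted into `.options()` on a Ray actor or task. For context, a minimal standalone sketch of the underlying Ray placement-group API (independent of the repo; two CPU-only bundles here, whereas the functions above also reserve GPU/NPU resources):

# Hypothetical example of placement-group scheduling with Ray.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="PACK")  # one bundle per worker
ray.get(pg.ready())                                              # wait until resources are reserved

@ray.remote
def which_node():
    return ray.util.get_node_ip_address()

refs = [
    which_node.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg, placement_group_bundle_index=i
        )
    ).remote()
    for i in range(2)
]
print(ray.get(refs))  # IP address of the node hosting each bundle
ray.shutdown()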
@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
 from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo

@@ -34,12 +34,17 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+    get_placement_group,
+    get_ray_head_node_ip,
+    get_ray_remote_config_for_worker,
+    get_swanlab_callback,
+    sort_placement_group_by_node_ip,
+)


 if is_ray_available():
     import ray
-    from ray.train.huggingface.transformers import RayTrainReportCallback


 if TYPE_CHECKING:

@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
     ray_args = get_ray_args(args)
     callbacks = callbacks or []
     if ray_args.use_ray:
-        callbacks.append(RayTrainReportCallback())
-        trainer = get_ray_trainer(
-            training_function=_training_function,
-            train_loop_config={"args": args, "callbacks": callbacks},
-            ray_args=ray_args,
-        )
-        trainer.fit()
+        _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
     else:
         _training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         with open(ollama_modelfile, "w", encoding="utf-8") as f:
             f.write(template.get_ollama_modelfile(tokenizer))
         logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+
+
+class Worker:
+    def __init__(self):
+        self._setup_env_visible_devices()
+
+        local_rank = os.environ.get("LOCAL_RANK", "0")
+        get_torch_device().set_device(int(local_rank))
+
+    def _setup_env_visible_devices(self) -> None:
+        RAY_NOSET_VISIBLE_DEVICES_LIST = [
+            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+        ]
+        is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+        if is_ray_noset_visible_devices:
+            device_name = get_device_name().upper()
+            local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+            os.environ["LOCAL_RANK"] = local_rank
+        else:
+            os.environ["LOCAL_RANK"] = "0"
+
+    def _training_function(self, config: dict[str, Any]) -> None:
+        _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+    num_workers = ray_args.ray_num_workers
+    master_addr = ray_args.master_addr
+    master_port = ray_args.master_port
+    logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+    # initialize ray
+    if not ray.is_initialized():
+        if ray_args.ray_init_kwargs is not None:
+            ray.init(**ray_args.ray_init_kwargs)
+        else:
+            ray.init()
+
+    # verify resources
+    device_name = get_device_name().upper()
+    total_devices = int(ray.cluster_resources().get(device_name, 0))
+    if num_workers > total_devices:
+        raise ValueError(
+            f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
+        )
+
+    # verify master_addr
+    if master_addr is None:
+        master_addr = get_ray_head_node_ip()
+        logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+    else:
+        nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+        if master_addr not in nodes:
+            raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")
+
+    # create placementgroup for resource management
+    pg, bundle = get_placement_group(total_devices)
+    ray.get(pg.ready())
+    logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
+
+    # get sorted_bundle_indices
+    sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+    # get master port
+    if master_port is None:
+        master_port = find_available_port()
+        logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+    master_port = str(master_port)
+
+    # backing up environment variables
+    current_env = dict(os.environ.items())
+
+    # launch workers
+    RayWorker = ray.remote(Worker)
+    workers = []
+    for rank in range(num_workers):
+        remote_config = get_ray_remote_config_for_worker(
+            placement_group=pg,
+            bundle_idx=sorted_bundle_indices[rank],
+            rank=rank,
+            world_size=num_workers,
+            master_addr=master_addr,
+            master_port=master_port,
+            env=current_env,
+        )
+        worker = RayWorker.options(**remote_config).remote()
+        workers.append(worker)
+
+    ray.get([worker._training_function.remote(config=config) for worker in workers])
+    ray.shutdown()
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
 """

 import os
-import sys

 import pytest
 import torch

@@ -149,14 +148,7 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
             devices_str = ",".join(str(i) for i in range(required))

         monkeypatch.setenv(env_key, devices_str)
-
-        # add project root dir to path for mp run
-        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-        if project_root not in sys.path:
-            sys.path.insert(0, project_root)
-
-        os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
+        monkeypatch.syspath_prepend(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

     else:  # non-distributed test
         if old_value:
             visible_devices = [v for v in old_value.split(",") if v != ""]
@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module

@@ -52,7 +53,12 @@ def test_feedback_data(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_labels
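These test updates all follow the same pattern: on transformers releases newer than 5.0.0, the result of `apply_chat_template(..., tokenize=True)` is treated as a mapping and the token ids are pulled out of `input_ids`, while older releases return the id list directly. A small hypothetical helper capturing that pattern (the function name is illustrative, not from the repo):

from llamafactory.extras.packages import is_transformers_version_greater_than

def chat_template_ids(tokenizer, messages, **kwargs):
    # Normalize apply_chat_template output across transformers versions.
    out = tokenizer.apply_chat_template(messages, **kwargs)
    if is_transformers_version_greater_than("5.0.0"):
        out = out["input_ids"]
    return out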
@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module

@@ -63,13 +64,21 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)

         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
-        chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
-        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
+        ref_chosen_prompt_ids = ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True)
         ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
-        rejected_prompt_len = len(
-            ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
-        )
+        ref_rejected_prompt_ids = ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_chosen_input_ids = ref_chosen_input_ids["input_ids"]
+            ref_rejected_input_ids = ref_rejected_input_ids["input_ids"]
+            ref_chosen_prompt_ids = ref_chosen_prompt_ids["input_ids"]
+            ref_rejected_prompt_ids = ref_rejected_prompt_ids["input_ids"]
+
+        chosen_prompt_len = len(ref_chosen_prompt_ids)
+        rejected_prompt_len = len(ref_rejected_prompt_ids)
+        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
         ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
         assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
         assert train_dataset["chosen_labels"][index] == ref_chosen_labels
@@ -20,6 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer

 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module

@@ -59,7 +60,16 @@ def test_supervised_single_turn(num_samples: int):
             {"role": "assistant", "content": original_data["output"][index]},
         ]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
+        ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_label_ids


 @pytest.mark.runs_on(["cpu", "mps"])

@@ -73,6 +83,10 @@ def test_supervised_multi_turn(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        # cannot test the label ids in multi-turn case
         assert train_dataset["input_ids"][index] == ref_input_ids

@@ -86,9 +100,12 @@ def test_supervised_train_on_prompt(num_samples: int):
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
-        ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
-        assert train_dataset["input_ids"][index] == ref_ids
-        assert train_dataset["labels"][index] == ref_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+
+        assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_input_ids


 @pytest.mark.runs_on(["cpu", "mps"])

@@ -103,7 +120,13 @@ def test_supervised_mask_history(num_samples: int):
     for index in indexes:
         messages = original_data["messages"][index]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
-        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        prompt_len = len(ref_prompt_ids)
         ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
         assert train_dataset["input_ids"][index] == ref_input_ids
         assert train_dataset["labels"][index] == ref_label_ids
@@ -19,6 +19,7 @@ import pytest
 from datasets import load_dataset
 from transformers import AutoTokenizer

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_dataset_module

@@ -55,8 +56,13 @@ def test_unsupervised_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
-        ref_ids = ref_tokenizer.apply_chat_template(messages)
-        ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
-        ref_labels = ref_ids[len(ref_input_ids) :]
-        assert train_dataset["input_ids"][index] == ref_input_ids
+        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        ref_prompt_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
+        if is_transformers_version_greater_than("5.0.0"):
+            ref_input_ids = ref_input_ids["input_ids"]
+            ref_prompt_ids = ref_prompt_ids["input_ids"]
+
+        ref_labels = ref_input_ids[len(ref_prompt_ids) :]
+        assert train_dataset["input_ids"][index] == ref_prompt_ids
         assert train_dataset["labels"][index] == ref_labels
@@ -17,7 +17,7 @@ import os
 import pytest
 import torch
 from PIL import Image
-from transformers import AutoConfig, AutoModelForVision2Seq
+from transformers import AutoConfig, AutoModelForImageTextToText

 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask

@@ -82,7 +82,7 @@ def test_multimodal_collator():
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
@@ -20,6 +20,7 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from llamafactory.data import get_template_and_fix_tokenizer
|
from llamafactory.data import get_template_and_fix_tokenizer
|
||||||
from llamafactory.data.template import parse_template
|
from llamafactory.data.template import parse_template
|
||||||
|
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||||
from llamafactory.hparams import DataArguments
|
from llamafactory.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -65,7 +66,6 @@ def _check_template(
|
|||||||
template_name: str,
|
template_name: str,
|
||||||
prompt_str: str,
|
prompt_str: str,
|
||||||
answer_str: str,
|
answer_str: str,
|
||||||
use_fast: bool,
|
|
||||||
messages: list[dict[str, str]] = MESSAGES,
|
messages: list[dict[str, str]] = MESSAGES,
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""Check template.
|
r"""Check template.
|
||||||
@@ -75,13 +75,15 @@ def _check_template(
|
|||||||
template_name: the template name.
|
template_name: the template name.
|
||||||
prompt_str: the string corresponding to the prompt part.
|
prompt_str: the string corresponding to the prompt part.
|
||||||
answer_str: the string corresponding to the answer part.
|
answer_str: the string corresponding to the answer part.
|
||||||
use_fast: whether to use fast tokenizer.
|
|
||||||
messages: the list of messages.
|
messages: the list of messages.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
content_str = tokenizer.apply_chat_template(messages, tokenize=False)
|
content_str = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
|
content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
|
||||||
|
if is_transformers_version_greater_than("5.0.0"):
|
||||||
|
content_ids = content_ids["input_ids"]
|
||||||
|
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
|
||||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
|
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
|
||||||
assert content_str == prompt_str + answer_str
|
assert content_str == prompt_str + answer_str
|
||||||
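The guard added above handles the transformers 5.x behavior that this hunk implies: apply_chat_template(..., tokenize=True) appears to return a mapping with an "input_ids" entry rather than a bare list of token ids. A hedged sketch of the same normalization, reusing the is_transformers_version_greater_than helper that the hunk imports; the tiny checkpoint id is only an example:

from transformers import AutoTokenizer

from llamafactory.extras.packages import is_transformers_version_greater_than

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")  # example checkpoint
messages = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
if is_transformers_version_greater_than("5.0.0"):
    content_ids = content_ids["input_ids"]  # unwrap the mapping returned by newer releases

assert isinstance(content_ids, list)  # a flat list of token ids on either version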
@@ -90,9 +92,8 @@ def _check_template(


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_oneturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_oneturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
@@ -106,9 +107,8 @@ def test_encode_oneturn(use_fast: bool):


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_encode_multiturn(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_encode_multiturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
@@ -128,11 +128,10 @@ def test_encode_multiturn(use_fast: bool):


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_oneturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -155,11 +154,10 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
-def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
+def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
@@ -185,10 +183,9 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_jinja_template(use_fast: bool):
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+def test_jinja_template():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
     assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -222,8 +219,7 @@ def test_get_stop_token_ids():

 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma_template(use_fast: bool):
+def test_gemma_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -231,13 +227,12 @@ def test_gemma_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_gemma2_template(use_fast: bool):
+def test_gemma2_template():
     prompt_str = (
         f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
         f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
@@ -245,13 +240,12 @@ def test_gemma2_template(use_fast: bool):
         "<start_of_turn>model\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
-    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
+    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_llama3_template(use_fast: bool):
+def test_llama3_template():
     prompt_str = (
         f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
         f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
@@ -259,14 +253,11 @@ def test_llama3_template(use_fast: bool):
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
-    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize(
-    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
-)
-def test_llama4_template(use_fast: bool):
+def test_llama4_template():
     prompt_str = (
         f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
         f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
@@ -274,18 +265,11 @@ def test_llama4_template(use_fast: bool):
         "<|header_start|>assistant<|header_end|>\n\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|eot|>"
-    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
+    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str)


-@pytest.mark.parametrize(
-    "use_fast",
-    [
-        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
-        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
-    ],
-)
 @pytest.mark.runs_on(["cpu", "mps"])
-def test_phi4_template(use_fast: bool):
+def test_phi4_template():
     prompt_str = (
         f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
         f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
@@ -293,13 +277,12 @@ def test_phi4_template(use_fast: bool):
         "<|im_start|>assistant<|im_sep|>"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
-    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
+    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
-@pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen2_5_template(use_fast: bool):
+def test_qwen2_5_template():
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
@@ -308,13 +291,12 @@ def test_qwen2_5_template(use_fast: bool):
         "<|im_start|>assistant\n"
     )
     answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
-    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
+    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str)


 @pytest.mark.runs_on(["cpu", "mps"])
-@pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
-def test_qwen3_template(use_fast: bool, cot_messages: bool):
+def test_qwen3_template(cot_messages: bool):
     prompt_str = (
         f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
         f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
@@ -328,12 +310,12 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
         answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
         messages = MESSAGES_WITH_THOUGHT

-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, messages=messages)


 @pytest.mark.runs_on(["cpu", "mps"])
 def test_parse_llama3_template():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
     template = parse_template(tokenizer)
     assert template.format_user.slots == [
         "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -348,7 +330,7 @@ def test_parse_llama3_template():
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
@@ -361,7 +343,7 @@ def test_parse_qwen_template():
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
 def test_parse_qwen3_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     template = parse_template(tokenizer)
     assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
@@ -16,7 +16,8 @@ import os

 import pytest
 import torch
-from transformers import AutoConfig, AutoModelForVision2Seq
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForImageTextToText

 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import FinetuningArguments, ModelArguments
@@ -36,7 +37,7 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     for name, param in model.named_parameters():
@@ -56,7 +57,7 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
     trainable_params, frozen_params = set(), set()
@@ -86,13 +87,14 @@ def test_visual_model_save_load():
     finetuning_args = FinetuningArguments(finetuning_type="full")
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
-        model = AutoModelForVision2Seq.from_config(config)
+        model = AutoModelForImageTextToText.from_config(config)

     model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
+    model.to_empty(device="cpu")
     loaded_model_weight = dict(model.named_parameters())

-    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
-    saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
+    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=True)
+    saved_model_weight = load_file(os.path.join("output", "qwen2_vl", "model.safetensors"))

     if is_transformers_version_greater_than("4.52.0"):
         assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
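Besides the auto-class rename, the last hunk materializes the meta-initialized model with to_empty() and switches the save/load round trip from pytorch_model.bin to safetensors. A minimal sketch of that round trip, assuming safetensors is installed; the checkpoint id and output directory are examples only, not the repository's fixtures:

import os

import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3")  # example tiny checkpoint
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

model.to_empty(device="cpu")  # give the meta parameters real (uninitialized) CPU storage
save_dir = os.path.join("output", "tiny_llama3")
model.save_pretrained(save_dir, safe_serialization=True)

state_dict = load_file(os.path.join(save_dir, "model.safetensors"))  # a tiny model fits in one shard
print(len(state_dict), "tensors saved")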
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.105
+0.9.5.106
@@ -23,6 +23,13 @@ from llamafactory.v1.core.utils.rendering import Renderer
 from llamafactory.v1.utils.types import Processor


+def _get_input_ids(inputs: list | dict) -> list:
+    if not isinstance(inputs, list):
+        return inputs["input_ids"]
+    else:
+        return inputs
+
+
 HF_MESSAGES = [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "What is LLM?"},
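The _get_input_ids helper added above lets the rendering tests accept either return type of apply_chat_template: the plain token-id list from older transformers releases or the mapping that newer ones appear to return. A small usage sketch under that assumption, with the tiny tokenizer id taken from the surrounding tests:

from transformers import AutoTokenizer

def _get_input_ids(inputs: list | dict) -> list:
    # Unwrap the mapping form when apply_chat_template does not return a bare list.
    if not isinstance(inputs, list):
        return inputs["input_ids"]
    else:
        return inputs

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
messages = [{"role": "user", "content": "What is LLM?"}]
ids = _get_input_ids(tokenizer.apply_chat_template(messages, add_generation_prompt=True))
assert all(isinstance(token_id, int) for token_id in ids)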
@@ -81,15 +88,15 @@ def test_chatml_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)

-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
     v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)

-    hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
+    hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
     v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -124,17 +131,21 @@ def test_qwen3_nothink_rendering():
     tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
     renderer = Renderer(template="qwen3_nothink", processor=tokenizer)

-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    hf_inputs = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
+    )
     v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
     assert v1_inputs["input_ids"] == hf_inputs
     assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
     assert v1_inputs["labels"] == [-100] * len(hf_inputs)
     assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)

-    hf_inputs_part = tokenizer.apply_chat_template(
-        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
-    )
-    hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    hf_inputs_part = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
+    )
+    hf_inputs_full = _get_input_ids(
+        tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
+    )
     v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
     assert v1_inputs_full["input_ids"] == hf_inputs_full
     assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
@@ -187,7 +198,7 @@ def test_qwen3_nothink_rendering_remote(num_samples: int):
 def test_process_sft_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))

     samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
     model_inputs = renderer.process_samples(samples)
@@ -200,7 +211,7 @@ def test_process_sft_samples():
 def test_process_dpo_samples():
     tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
     renderer = Renderer(template="chatml", processor=tokenizer)
-    hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES)
+    hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))

     samples = [
         {