From eae6f0b54188e9d297de907be9c06c27b3553613 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Sun, 5 Apr 2026 12:10:28 +0800 Subject: [PATCH] [model] gemma4 (#10346) --- .ai/CLAUDE.md | 105 ++++++++++ src/llamafactory/data/mm_plugin.py | 197 +++++++++++++++++- src/llamafactory/data/template.py | 49 +++++ src/llamafactory/data/tool_utils.py | 159 ++++++++++++++ src/llamafactory/extras/constants.py | 28 +++ src/llamafactory/model/model_utils/visual.py | 7 + .../train/hyper_parallel/workflow.py | 8 +- tests/data/test_mm_plugin.py | 30 ++- 8 files changed, 576 insertions(+), 7 deletions(-) create mode 100644 .ai/CLAUDE.md diff --git a/.ai/CLAUDE.md b/.ai/CLAUDE.md new file mode 100644 index 000000000..e211a26f3 --- /dev/null +++ b/.ai/CLAUDE.md @@ -0,0 +1,105 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +```bash +# Code style (auto-fix) +make style + +# Code quality check (no modifications) +make quality + +# Run all tests +make test + +# Run a single test file +WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py + +# Run tests matching a pattern +WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name" + +# License header check +make license + +# Build package +make build +``` + +The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available. + +## Architecture + +LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable: + +- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras` +- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils` + +Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`. 
+ +### Entry Points + +CLI entry point is `llamafactory-cli` / `lmf` → `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`. + +Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`. + +### Training Flow (v0) + +``` +run_exp() [tuner.py] + → read_args() → parse YAML/JSON config + → get_train_args() → produces typed argument dataclasses + → routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto + → optional: export_model() +``` + +Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml` + +### Configuration System + +All training parameters are YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses: +- `ModelArguments` — model/tokenizer selection, quantization +- `DataArguments` — datasets, templates, preprocessing +- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto) +- `TrainingArguments` — extends HuggingFace's `TrainingArguments` + +### Key Modules + +| Module | Purpose | +|--------|---------| +| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches | +| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches | +| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format | +| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling | +| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) | +| `src/llamafactory/train/sft/` | SFT trainer; other stages follow same structure | +| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` | +| `src/llamafactory/extras/constants.py` | Enums and constants used across the project | + +### Adding Support for a New Model + +1. 
Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict +2. Add any necessary model patches in `src/llamafactory/model/patcher.py` +3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed + +### Distributed Training + +Multi-GPU automatically uses `torchrun`. Additional backends: +- **Ray:** Optional Ray cluster support +- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/` +- **Megatron-core:** `src/llamafactory/train/mca/` + +### Testing + +- `tests/` — v0 tests; `tests_v1/` — v1 tests +- Most training tests require GPU hardware +- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])` +- Always set `WANDB_DISABLED=true` when running tests + +### Code Style + +- Ruff for linting and formatting (line length 119, Google-style docstrings) +- Python 3.11+ syntax +- Double quotes for strings +- All new files must include Apache 2.0 license header (checked by `make license`) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index fe87f53e8..62a90ef65 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -607,6 +607,194 @@ class Gemma3nPlugin(Gemma3Plugin): return messages +@dataclass +class Gemma4Plugin(BasePlugin): + r"""Plugin for the Gemma4 multimodal model.""" + + @override + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput": + r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation.""" + results, fps_per_video, durations, frames_indices = [], [], [], [] + for video in videos: + frames: list[ImageObject] = [] + if _check_video_is_nested_images(video): + frames = video + fps_per_video.append(kwargs.get("video_fps", 2.0)) + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + frames_indices.append(list(range(len(frames)))) + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if 
stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + original_fps = float(video_stream.average_rate) + # rescale sampled frame indices to the target fps so that timestamps are calculated correctly + frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + if video_stream.duration is None: + durations.append(len(frames) / kwargs.get("video_fps", 2.0)) + else: + durations.append(float(video_stream.duration * video_stream.time_base)) + + frames = self._regularize_images(frames, **kwargs)["images"] + results.append(frames) + + return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices} + + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + image_processor = getattr(processor, "image_processor", None) + video_processor = getattr(processor, "video_processor", None) + feature_extractor = getattr(processor, "feature_extractor", None) + mm_inputs = {} + + if len(images) != 0: + regularized = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(regularized, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + video_metadata = [ + {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": 
len(video), "frames_indices": sample_indices} + for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"]) + ] + mm_inputs.update( + video_processor( + videos=video_data["videos"], + video_metadata=video_metadata, + return_tensors="pt", + return_metadata=True, + do_sample_frames=False, + ) + ) + + if len(audios) != 0: # only for gemma4n + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + + mm_inputs.update( + feature_extractor( + audios, + padding="max_length", + return_tensors="pt", + ) + ) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + + boi_token: str = getattr(processor, "boi_token") + eoi_token: str = getattr(processor, "eoi_token") + boa_token: str = getattr(processor, "boa_token") + eoa_token: str = getattr(processor, "eoa_token") + image_token: str = getattr(processor, "image_token") + video_token: str = getattr(processor, "video_token") + audio_token: str = getattr(processor, "audio_token") + + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + num_image_soft_tokens: list[int] = list( + mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images)) + ) + num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos))) + video_metadata = mm_inputs.get("video_metadata", []) + else: + num_image_soft_tokens = [1] * len(images) + num_video_soft_tokens = [1] * len(videos) + video_metadata = [None] * len(videos) + + audio_iter = iter(audios) + 
image_iter = iter(num_image_soft_tokens) + video_iter = iter(zip(num_video_soft_tokens, video_metadata)) + + for message in messages: + content = message["content"] + + while IMAGE_PLACEHOLDER in content: + n = next(image_iter) + content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1) + + while VIDEO_PLACEHOLDER in content: + num_soft_tokens_per_frame, metadata = next(video_iter) + if self.expand_mm_tokens: + timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps] + frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs] + video_str = " ".join(frame_strs) + else: + video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" + content = content.replace(VIDEO_PLACEHOLDER, video_str, 1) + + while AUDIO_PLACEHOLDER in content: + current_audio = next(audio_iter) + if self.expand_mm_tokens: + num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate) + audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}" + else: + audio_str = f"{boa_token}{audio_token}{eoa_token}" + + content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + # Pop metadata keys that must not be passed to the model. 
+ for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata", + "_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"): + mm_inputs.pop(key, None) + + mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids) + + return mm_inputs + + @dataclass class InternVLPlugin(BasePlugin): @override @@ -1505,7 +1693,7 @@ class Qwen2VLPlugin(BasePlugin): else: container = av.open(video, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) original_fps = float(video_stream.average_rate) # for qwen3vl video timestamp calculation frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False @@ -1642,7 +1830,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin): video_maxlen=getattr(processor, "video_maxlen", 128), ) video_metadata = [ - {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices} + {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices} for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"]) ] mm_inputs.update( @@ -1683,7 +1871,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin): image_grid_thw = mm_inputs.get("image_grid_thw", []) video_grid_thw = mm_inputs.get("video_grid_thw", []) num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now - video_metadata = mm_inputs.get("video_metadata", {}) + video_metadata = mm_inputs.get("video_metadata", []) else: image_grid_thw = [None] * len(images) @@ -2206,8 +2394,9 @@ PLUGINS = { "base": BasePlugin, "ernie_vl": ErnieVLPlugin, "gemma3": Gemma3Plugin, - "glm4v": GLM4VPlugin, 
"gemma3n": Gemma3nPlugin, + "gemma4": Gemma4Plugin, + "glm4v": GLM4VPlugin, "intern_vl": InternVLPlugin, "kimi_vl": KimiVLPlugin, "llama4": Llama4Plugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index c4f21b8ee..5293a2f32 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -997,6 +997,55 @@ register_template( ) +register_template( + name="gemma4", + format_user=StringFormatter(slots=["<|turn>user\n{{content}}\n<|turn>model\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}\n"]), # default thought singal contained + format_observation=StringFormatter( + slots=["<|turn>tool\n{{content}}\n<|turn>model\n"] + ), # seem not consistent with the chattemplate + format_tools=ToolFormatter(tool_format="gemma4"), + format_function=FunctionFormatter(slots=["<|tool>{{content}}"], tool_format="gemma4"), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + default_system="You are a helpful assistant.", # important for thinking + thought_words=("<|channel>thought\n", ""), + replace_eos=True, + mm_plugin=get_mm_plugin( + "gemma4", + image_token="<|image|>", + video_token="<|video|>", + ), + template_class=ReasoningTemplate, +) + + +register_template( + name="gemma4n", + format_user=StringFormatter(slots=["<|turn>user\n{{content}}\n<|turn>model\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}\n"]), # default thought singal contained + format_observation=StringFormatter( + slots=["<|turn>tool\n{{content}}\n<|turn>model\n"] + ), + format_tools=ToolFormatter(tool_format="gemma4"), + format_function=FunctionFormatter(slots=["<|tool>{{content}}"], tool_format="gemma4"), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + default_system="You are a helpful assistant.", # important 
for thinking + thought_words=("<|channel>thought\n", ""), + replace_eos=True, + mm_plugin=get_mm_plugin( + "gemma4", + image_token="<|image|>", + video_token="<|video|>", + audio_token="<|audio|>", + ), + template_class=ReasoningTemplate, +) + + register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 8ca60cdad..4c3ec5ce6 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -209,6 +209,164 @@ class DefaultToolUtils(ToolUtils): return results +class Gemma4ToolUtils(ToolUtils): + r"""Gemma-4 tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + def _format_parameters(properties: dict[str, Any]) -> str: + parts: list[str] = [] + for name, schema in properties.items(): + item_parts: list[str] = [] + if schema.get("description"): + item_parts.append(f'description:<|"|>{schema["description"]}<|"|>') + if schema.get("type"): + item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>') + parts.append(f"{name}:{{{','.join(item_parts)}}}") + + return ",".join(parts) + + declarations: list[str] = [] + for tool in tools: + function_data = tool.get("function", tool) if tool.get("type") == "function" else tool + declaration = ( + f"declaration:{function_data['name']}" + + "{" + + f'description:<|"|>{function_data.get("description", "")}<|"|>' + ) + + params = function_data.get("parameters") + if params: + param_parts: list[str] = [] + if params.get("properties"): + param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}") + + if params.get("required"): + required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"]) + param_parts.append(f"required:[{required_text}]") + + if params.get("type"): + param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>') + + declaration += 
f",parameters:{{{','.join(param_parts)}}}" + + response_declaration = function_data.get("response") + if response_declaration: + response_parts: list[str] = [] + if response_declaration.get("description"): + response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>') + + response_type = str(response_declaration.get("type", "")).upper() + + if response_type == "OBJECT": + response_parts.append(f'type:<|"|>{response_type}<|"|>') + + declaration += f",response:{{{','.join(response_parts)}}}" + + declarations.append(declaration + "}") + + return "\n".join(declarations) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}", re.DOTALL) + matches = re.findall(regex, content) + if not matches: + return content + + def _parse_arguments(arg_text: str) -> Any: + text = arg_text.strip() + if not text: + return {} + + # `function_formatter` writes dict arguments as `k:v,...` inside `{...}`. + # The extractor captures only the inner text, so re-wrap it to parse as JSON object. + object_like_text = "{" + text + "}" + # Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings. + normalized = re.sub( + r"<\|\"\|\>(.*?)<\|\"\|\>", + lambda m: json.dumps(m.group(1), ensure_ascii=False), + object_like_text, + flags=re.DOTALL, + ) + # Quote unquoted object keys so the payload can be parsed by json.loads. 
+ normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized) + try: + return json.loads(normalized) + except json.JSONDecodeError: + pass + + try: + return json.loads(text) + except json.JSONDecodeError: + return text + + results: list[FunctionCall] = [] + for name, arg_block in matches: + parsed_arguments = _parse_arguments(arg_block) + if isinstance(parsed_arguments, str): + arguments = parsed_arguments + else: + arguments = json.dumps(parsed_arguments, ensure_ascii=False) + results.append(FunctionCall(name.strip(), arguments)) + + return results + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + def _format_argument(argument: Any, escape_keys: bool = True) -> str: + if isinstance(argument, str): + return f'<|"|>{argument}<|"|>' + + if isinstance(argument, bool): + return "true" if argument else "false" + + if isinstance(argument, dict): + items: list[str] = [] + for key in sorted(argument.keys()): + formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key) + formatted_value = _format_argument(argument[key], escape_keys=escape_keys) + items.append(f"{formatted_key}:{formatted_value}") + return "{" + ",".join(items) + "}" + + if isinstance(argument, (list, tuple)): + return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]" + + if argument is None: + return "null" + + return str(argument) + + function_texts: list[str] = [] + for function in functions: + name = function.name + raw_arguments = function.arguments + + try: + parsed_arguments = json.loads(raw_arguments) + except (TypeError, json.JSONDecodeError): + parsed_arguments = raw_arguments + + call_text = f"<|tool_call>call:{name}" + "{" + if isinstance(parsed_arguments, dict): + args_text = [] + for key in sorted(parsed_arguments.keys()): + value_text = _format_argument(parsed_arguments[key], escape_keys=False) + args_text.append(f"{key}:{value_text}") + + call_text += ",".join(args_text) 
+ elif isinstance(parsed_arguments, str): + call_text += parsed_arguments + else: + call_text += _format_argument(parsed_arguments, escape_keys=False) + + call_text += "}" + function_texts.append(call_text) + + return "".join(function_texts) class GLM4ToolUtils(ToolUtils): r"""GLM-4 tool using template.""" @@ -723,6 +881,7 @@ class LFM2ToolUtils(ToolUtils): TOOLS = { "default": DefaultToolUtils(), + "gemma4": Gemma4ToolUtils(), "glm4": GLM4ToolUtils(), "llama3": Llama3ToolUtils(), "lfm2": LFM2ToolUtils(), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 87fb1f1c5..5c30ffd4b 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -865,6 +865,34 @@ register_model_group( ) +register_model_group( + models={ + "Gemma-4-26B-A4B-Thinking": { + DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it", + }, + "Gemma-4-31B-Thinking": { + DownloadSource.DEFAULT: "google/gemma-4-31B-it", + }, + }, + template="gemma4", + multimodal=True, +) + + +register_model_group( + models={ + "Gemma-4-E2B-Thinking": { + DownloadSource.DEFAULT: "google/gemma-4-E2B-it", + }, + "Gemma-4-E4B-Thinking": { + DownloadSource.DEFAULT: "google/gemma-4-E4B-it", + }, + }, + template="gemma4n", + multimodal=True, +) + + register_model_group( models={ "GLM-4-9B": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 3346d7fc1..df3eaa20c 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -219,6 +219,13 @@ _register_composite_model( ) +_register_composite_model( + model_type="gemma4", + vision_model_keys=["vision_tower", "audio_tower"], + lora_conflict_keys=["per_layer_projection_norm"], +) + + # copied from qwen2vl _register_composite_model( model_type="glm4v", diff --git a/src/llamafactory/train/hyper_parallel/workflow.py b/src/llamafactory/train/hyper_parallel/workflow.py index dd63901d8..85326ca09 100644 
--- a/src/llamafactory/train/hyper_parallel/workflow.py +++ b/src/llamafactory/train/hyper_parallel/workflow.py @@ -48,7 +48,10 @@ def run_sft( "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`." ) - from hyper_parallel.integration.llamafactory import HyperParallelArguments, HyperParallelTrainer # pylint: disable=C0415 + from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415 + HyperParallelArguments, + HyperParallelTrainer, + ) tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] @@ -128,9 +131,10 @@ def run_sft( ) if finetuning_args.use_badam: - from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import] from types import MethodType + from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import] + trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator) trainer.add_callback(BAdamCallback) diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 3187004aa..17c7f08a6 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -57,7 +57,7 @@ TEXT_MESSAGES = [ ] VIDEO_MESSAGES = [ - {"role": "user", "content": "