[trainer] update config (#7174)
Former-commit-id: 9f535d0e3c4ee3cd0f1b65218c2eee5d03f43c6f
This commit is contained in:
@@ -521,9 +521,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
-num_image_tokens = 0
|
||||
-num_video_tokens = 0
|
||||
-num_audio_tokens = 0
|
||||
+num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
@@ -1038,7 +1036,7 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
|
||||
|
||||
@dataclass
|
||||
-class Qwen2vlPlugin(BasePlugin):
|
||||
+class Qwen2VLPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
image = super()._preprocess_image(image, **kwargs)
|
||||
@@ -1124,7 +1122,10 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
+num_image_tokens, num_video_tokens = 0, 0
|
||||
+messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@@ -1134,8 +1135,6 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
image_grid_thw = [None] * len(images)
|
||||
video_grid_thw = [None] * len(videos)
|
||||
|
||||
-num_image_tokens, num_video_tokens = 0, 0
|
||||
-messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
@@ -1273,7 +1272,7 @@ PLUGINS = {
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
"pixtral": PixtralPlugin,
|
||||
"qwen2_audio": Qwen2AudioPlugin,
|
||||
-"qwen2_vl": Qwen2vlPlugin,
|
||||
+"qwen2_vl": Qwen2VLPlugin,
|
||||
"video_llava": VideoLlavaPlugin,
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import shutil
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
+import torch.distributed as dist
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
@@ -76,6 +77,12 @@ def _training_function(config: Dict[str, Any]) -> None:
|
||||
else:
|
||||
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
|
||||
|
||||
+try:
|
||||
+if dist.is_initialized():
|
||||
+dist.destroy_process_group()
|
||||
+except Exception as e:
|
||||
+logger.warning(f"Failed to destroy process group: {e}.")
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||
args = read_args(args)
|
||||
|
||||
Reference in New Issue
Block a user