update examples

Former-commit-id: 369294b31c8a03a1cafcee83eb31a817007d3c49
This commit is contained in:
hiyouga
2024-04-15 22:14:34 +08:00
parent 86556b1c74
commit 276f2cb24e
11 changed files with 78 additions and 70 deletions

View File

@@ -1,8 +1,6 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from transformers.utils.versions import require_version
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
@@ -25,7 +23,6 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
self.can_generate = finetuning_args.stage == "sft"
engine_args = AsyncEngineArgs(
model=model_args.model_name_or_path,

View File

@@ -49,10 +49,6 @@ def is_starlette_available():
return _is_package_available("sse_starlette")
def is_unsloth_available():
return _is_package_available("unsloth")
def is_uvicorn_available():
return _is_package_available("uvicorn")

View File

@@ -8,10 +8,10 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -74,6 +74,26 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
@@ -131,9 +151,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
@@ -158,6 +175,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
@@ -277,6 +295,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
raise ValueError("vLLM engine does not support RoPE scaling.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None:
model_args.device_map = {"": torch.device(model_args.export_device)}
@@ -298,6 +317,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"

View File

@@ -85,7 +85,9 @@ def load_model(
logger.warning("Unsloth does not support loading adapters.")
if model is None:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import numpy as np
from transformers.utils.versions import require_version
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
@@ -33,10 +32,6 @@ class ComputeMetrics:
r"""
Uses the model predictions to compute metrics.
"""
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
preds, labels = eval_preds
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

View File

@@ -5,7 +5,6 @@ from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
@@ -168,8 +167,6 @@ def _create_galore_optimizer(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
require_version("galore_torch", "To fix: pip install galore_torch")
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model)
else: