format style

Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00
Parent: 1750218057
Commit: 66e0e651b9
73 changed files with 1492 additions and 2325 deletions


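Every hunk below is a pure style change of the kind produced by a formatter such as black or ruff format (an assumption; the commit message only says "format style"): import names are sorted alphabetically, and a trailing comma is added after the last element of every multi-line call or literal. This "magic trailing comma" keeps the construct expanded one item per line and makes later additions show up as one-line diffs. A minimal before/after sketch with hypothetical names:

# Before formatting: unsorted names, no trailing comma.
from pkg import beta, alpha

kwargs = {
    "a": 1,
    "b": 2
}

# After formatting: names sorted, magic trailing comma added.
from pkg import alpha, beta

kwargs = {
    "a": 1,
    "b": 2,
}
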
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -7,12 +8,14 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
-from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
+from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
-    from ..hparams import ModelArguments, FinetuningArguments
+    from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
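
A note on the `if TYPE_CHECKING:` block whose imports are reordered above: `typing.TYPE_CHECKING` is `False` at runtime, so imports under it are seen only by static type checkers; this avoids circular imports and startup cost, which is why the annotations in this file are quoted strings. A minimal, self-contained sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime.
    from transformers import PreTrainedModel


def freeze(model: "PreTrainedModel") -> None:
    # The quoted annotation is lazily evaluated, so the runtime
    # never needs the import guarded above.
    model.requires_grad_(False)
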
@@ -29,7 +32,7 @@ def load_model_and_tokenizer(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False
+    add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    r"""
    Loads pretrained model and tokenizer.
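
For context, a hedged usage sketch of the function being reformatted; the module paths are assumed from the relative imports in the diff (the `llmtuner` package), and the argument values are illustrative:

from llmtuner.hparams import FinetuningArguments, ModelArguments
from llmtuner.model.loader import load_model_and_tokenizer

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
finetuning_args = FinetuningArguments()

# is_trainable=False (the default) loads the pair for inference only.
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
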
@@ -43,7 +46,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token
"token": model_args.hf_hub_token,
}
tokenizer = AutoTokenizer.from_pretrained(
@@ -51,7 +54,7 @@ def load_model_and_tokenizer(
        use_fast=model_args.use_fast_tokenizer,
        split_special_tokens=model_args.split_special_tokens,
        padding_side="right",
-        **config_kwargs
+        **config_kwargs,
    )
    patch_tokenizer(tokenizer)
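
The `config_kwargs` dict built above is expanded with `**` into the tokenizer call here and into the model constructor further down, keeping `cache_dir`, `revision`, and the Hub token consistent across both. A standalone equivalent of the tokenizer call, with placeholder values standing in for the `model_args` fields:

from transformers import AutoTokenizer

config_kwargs = {
    "trust_remote_code": True,
    "cache_dir": None,           # placeholder for model_args.cache_dir
    "revision": "main",          # placeholder for model_args.model_revision
    "token": None,               # placeholder for model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder for model_args.model_name_or_path
    use_fast=True,
    split_special_tokens=False,
    padding_side="right",
    **config_kwargs,
)
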
@@ -61,7 +64,8 @@ def load_model_and_tokenizer(
    model = None
    if is_trainable and model_args.use_unsloth:
        require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
-        from unsloth import FastLlamaModel, FastMistralModel # type: ignore
+        from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
        unsloth_kwargs = {
            "model_name": model_args.model_name_or_path,
            "max_seq_length": model_args.model_max_length,
@@ -69,7 +73,7 @@ def load_model_and_tokenizer(
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": get_current_device(),
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
if getattr(config, "model_type", None) == "llama":
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
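
The unsloth branch dispatches on `config.model_type`; only the Llama arm is visible in this hunk, but the `FastMistralModel` import implies a sibling arm. A hypothetical reconstruction of the dispatch (the `elif` is an assumption, not shown in the diff):

# Hypothetical sketch: reuses unsloth_kwargs from above.
model_type = getattr(config, "model_type", None)
if model_type == "llama":
    model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
elif model_type == "mistral":  # assumed arm, implied by the import
    model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
else:
    raise ValueError("Unsloth does not support model type: {}".format(model_type))
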
@@ -89,7 +93,7 @@ def load_model_and_tokenizer(
            config=config,
            torch_dtype=model_args.compute_dtype,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
-            **config_kwargs
+            **config_kwargs,
        )
    patch_model(model, tokenizer, model_args, is_trainable)
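
The `low_cpu_mem_usage` flag in this hunk is worth a note: it lets transformers materialize weights directly at their destination instead of building a full copy in CPU RAM first, but it must be disabled under DeepSpeed ZeRO-3, which partitions parameters itself at construction time, hence the guard. A minimal standalone form of the call:

from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

# ZeRO-3 shards parameters on its own, so the memory saver is
# switched off whenever ZeRO-3 is active.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
)
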
@@ -119,9 +123,11 @@ def load_model_and_tokenizer(
        model.train()
    trainable_params, all_param = count_parameters(model)
-    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-        trainable_params, all_param, 100 * trainable_params / all_param
-    ))
+    logger.info(
+        "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+            trainable_params, all_param, 100 * trainable_params / all_param
+        )
+    )
    if not is_trainable:
        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")