patch modelscope

Former-commit-id: 8888cf53f040f5a2d8c0e59cddf79b252449bf58
This commit is contained in:
hiyouga
2023-12-01 22:53:15 +08:00
parent ad9d866547
commit 72bbd5bdef
7 changed files with 312 additions and 222 deletions

View File

@@ -1,6 +1,4 @@
import math
import os
import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
@@ -23,8 +21,8 @@ try:
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
@@ -58,6 +56,8 @@ def load_model_and_tokenizer(
Support both training and inference.
"""
try_download_model_from_ms(model_args)
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
@@ -65,8 +65,6 @@ def load_model_and_tokenizer(
"token": model_args.hf_hub_token
}
try_download_model_from_ms(model_args)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
@@ -232,16 +230,3 @@ def load_model_and_tokenizer(
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
return model, tokenizer
def try_download_model_from_ms(model_args):
if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and not os.path.exists(model_args.model_name_or_path):
try:
from modelscope import snapshot_download
revision = model_args.model_revision
if revision == 'main':
revision = 'master'
model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, revision)
except ImportError as e:
raise ImportError(f'You are using `USE_MODELSCOPE_HUB=1` but you have no modelscope sdk installed. '
f'Please install it by `pip install modelscope -U`') from e