patch modelscope
Former-commit-id: 8888cf53f040f5a2d8c0e59cddf79b252449bf58
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
@@ -23,8 +21,8 @@ try:
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
@@ -58,6 +56,8 @@ def load_model_and_tokenizer(
|
||||
Support both training and inference.
|
||||
"""
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@@ -65,8 +65,6 @@ def load_model_and_tokenizer(
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
@@ -232,16 +230,3 @@ def load_model_and_tokenizer(
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args):
|
||||
if int(os.environ.get('USE_MODELSCOPE_HUB', '0')) and not os.path.exists(model_args.model_name_or_path):
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
revision = model_args.model_revision
|
||||
if revision == 'main':
|
||||
revision = 'master'
|
||||
model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, revision)
|
||||
except ImportError as e:
|
||||
raise ImportError(f'You are using `USE_MODELSCOPE_HUB=1` but you have no modelscope sdk installed. '
|
||||
f'Please install it by `pip install modelscope -U`') from e
|
||||
|
||||
Reference in New Issue
Block a user