support ms
Former-commit-id: fdd4f94f563110ef9f96ab4a7fd954def32e9785
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
@@ -63,6 +65,8 @@ def load_model_and_tokenizer(
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
@@ -228,3 +232,13 @@ def load_model_and_tokenizer(
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args):
|
||||
if os.environ.get('USE_MODELSCOPE_HUB', False) and not os.path.exists(model_args.model_name_or_path):
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, model_args.model_revision)
|
||||
except ImportError as e:
|
||||
raise ImportError(f'You are using `USE_MODELSCOPE_HUB=True` but you have no modelscope sdk installed. '
|
||||
f'Please install it by `pip install modelscope -U`') from e
|
||||
|
||||
Reference in New Issue
Block a user