support ms

Former-commit-id: fdd4f94f563110ef9f96ab4a7fd954def32e9785
This commit is contained in:
yuze.zyz
2023-11-29 20:36:55 +08:00
parent 08d5340bd8
commit 9d125bf533
3 changed files with 188 additions and 1 deletions

View File

@@ -1,4 +1,6 @@
import math
import os
import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
@@ -63,6 +65,8 @@ def load_model_and_tokenizer(
"token": model_args.hf_hub_token
}
try_download_model_from_ms(model_args)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
@@ -228,3 +232,13 @@ def load_model_and_tokenizer(
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
return model, tokenizer
def try_download_model_from_ms(model_args):
if os.environ.get('USE_MODELSCOPE_HUB', False) and not os.path.exists(model_args.model_name_or_path):
try:
from modelscope import snapshot_download
model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, model_args.model_revision)
except ImportError as e:
raise ImportError(f'You are using `USE_MODELSCOPE_HUB=True` but you have no modelscope sdk installed. '
f'Please install it by `pip install modelscope -U`') from e