fix rm server
Former-commit-id: 81bc1638682a9fd01518f9f25250a6b584d2a9e6
This commit is contained in:
@@ -27,7 +27,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
if getattr(model, "_no_split_modules", None) is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
|
||||
|
||||
Reference in New Issue
Block a user