support distributed quantized training

Former-commit-id: 74ff23a4f36f859f791f7b4be6f1877edc68f12f
This commit is contained in:
hiyouga
2023-06-06 17:39:41 +08:00
parent ac6f50dedf
commit bf5ad34196
7 changed files with 20 additions and 18 deletions
@@ -38,7 +38,7 @@ from .config import (
 )
 from .other import (
-    get_logger,
+    get_main_logger,
     load_trainable_params,
     load_valuehead_params,
     print_trainable_params,
@@ -53,7 +53,7 @@ require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 def _init_adapter(
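For readers without the rest of the diff: `get_main_logger` is a new helper in `other.py` whose definition is not shown here. A minimal sketch of what such a rank-aware logger could look like, built only on the standard logging module and the LOCAL_RANK environment variable (the real implementation may differ, for example it could wrap accelerate's logger):

# Hypothetical sketch of a rank-aware logger helper; not the repo's actual code.
import logging
import os


class MainProcessAdapter(logging.LoggerAdapter):
    """Drops records on non-zero ranks unless called with main_process_only=False."""

    def __init__(self, logger: logging.Logger) -> None:
        super().__init__(logger, {})
        self.local_rank = int(os.environ.get("LOCAL_RANK") or 0)

    def log(self, level, msg, *args, main_process_only: bool = True, **kwargs):
        if main_process_only and self.local_rank != 0:
            return  # suppress duplicate messages from non-main processes
        super().log(level, msg, *args, **kwargs)


def get_main_logger(name: str) -> MainProcessAdapter:
    logging.basicConfig(level=logging.INFO)
    return MainProcessAdapter(logging.getLogger(name))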
@@ -190,9 +190,10 @@ def load_pretrained(
         else:
             raise NotImplementedError
         is_mergeable = False
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
+    if not is_trainable:
         config_kwargs["device_map"] = "auto"
 
     # Load and prepare pretrained models (without valuehead).
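The per-rank `device_map` is the core of this change: with bitsandbytes quantization, each data-parallel worker has to load the whole quantized model onto its own GPU, so the empty-string key (the entire model) is mapped to that process's LOCAL_RANK instead of letting Accelerate shard it via "auto", which now only applies to non-trainable (inference) loading. A minimal sketch of how such kwargs reach `from_pretrained`; the checkpoint name and the 8-bit flag are illustrative placeholders, not taken from this commit:

# Minimal sketch: per-rank placement of an 8-bit quantized model for DDP training.
# Assumes a `torchrun --nproc_per_node=N` launch; the checkpoint is a placeholder.
import os

import torch
from transformers import AutoModelForCausalLM

quantization_bit = 8   # e.g. model_args.quantization_bit
is_trainable = True    # training run under DDP

config_kwargs = {}
if quantization_bit == 8:
    config_kwargs["load_in_8bit"] = True
    # "" maps the whole model to one device: this process's GPU, not a sharded layout.
    config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}

if not is_trainable:
    # Inference-only loading may still let Accelerate place the model automatically.
    config_kwargs["device_map"] = "auto"

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",   # placeholder checkpoint
    torch_dtype=torch.float16,
    **config_kwargs,
)

Launched with `torchrun --nproc_per_node=N`, every rank ends up with its own full copy of the quantized model on its assigned GPU, which is the layout data-parallel training expects.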
@@ -288,7 +289,7 @@ def prepare_args(
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
         + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
+    , main_process_only=False)
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
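The `main_process_only=False` keyword lets the process-rank banner above be printed by every rank, while the parameter dump stays main-process-only. The keyword matches the interface of `accelerate.logging.get_logger`; whether `get_main_logger` wraps it is an assumption, but the usage pattern is the same:

# Usage sketch with accelerate's multi-process logger (assumption: get_main_logger
# behaves equivalently). Works single-GPU or under a torchrun launch.
import logging

from accelerate import PartialState
from accelerate.logging import get_logger

logging.basicConfig(level=logging.INFO)
PartialState()  # initialize the process state before using the logging utility

logger = get_logger(__name__)

# Emitted by every rank, so each process can report its own device placement.
logger.info("distributed environment is ready", main_process_only=False)

# Default behaviour: emitted once, by the main process only.
logger.info("training/evaluation parameters loaded")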