support BLOOM models

Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
Author: hiyouga
Date: 2023-05-31 16:54:06 +08:00
parent 181c776b58
commit 693c049eac
16 changed files with 134 additions and 90 deletions


@@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple
 import transformers
 from transformers import (
-    LlamaConfig,
-    LlamaForCausalLM,
-    LlamaTokenizer,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
     HfArgumentParser,
     Seq2SeqTrainingArguments
 )
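The point of this hunk is that the generic Auto* factories resolve the concrete config, model, and tokenizer classes from the checkpoint itself, so the same loading code now serves BLOOM as well as LLaMA. A minimal sketch of that dispatch, assuming the public bigscience/bloom-560m checkpoint purely as an example (it is not referenced by this commit):

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

name = "bigscience/bloom-560m"                   # example checkpoint, assumed for illustration
config = AutoConfig.from_pretrained(name)        # resolves to BloomConfig for a BLOOM checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)  # resolves to the matching BLOOM tokenizer
model = AutoModelForCausalLM.from_pretrained(name, config=config)  # resolves to BloomForCausalLM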
@@ -151,7 +151,7 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."

-    tokenizer = LlamaTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
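As a note on the tokenizer call above: AutoTokenizer forwards extra keyword arguments such as padding_side to the resolved tokenizer class, and left padding is what decoder-only models like BLOOM and LLaMA need for batched generation. A hedged sketch with an assumed checkpoint name:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "bigscience/bloom-560m",    # assumed example checkpoint
    use_fast=True,
    padding_side="left"         # pad on the left so generation continues from real tokens
)
batch = tokenizer(["Hello", "A much longer prompt"], padding=True, return_tensors="pt")
# Each row of batch["input_ids"] ends with real tokens; the padding sits at the front.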
@@ -173,13 +173,13 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

-    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)

     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the llama weights are float16 type
+        torch_dtype=torch.float16, # the model weights are float16 type
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
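The hunk above keeps the existing quantization plumbing: config_kwargs optionally carries load_in_8bit and device_map, and the weights are loaded in float16 regardless of architecture. A sketch of that call pattern with assumed values for the model path and quantization bit (8-bit loading additionally needs bitsandbytes and accelerate installed):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name_or_path = "bigscience/bloom-560m"   # assumed for illustration
quantization_bit = 8                           # assumed CLI value; None would skip quantization

config_kwargs = {}
if quantization_bit == 8:
    config_kwargs["load_in_8bit"] = True
    config_kwargs["device_map"] = "auto"       # only set together with load_in_8bit

config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    torch_dtype=torch.float16,                 # keep the weights in half precision
    **config_kwargs
)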
@@ -245,7 +245,7 @@ def prepare_args(
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

     if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
+        logger.warning("We recommend enabling fp16 mixed precision training.")

     if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
         logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")