support BLOOM models
Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
@@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple
 
 import transformers
 from transformers import (
-    LlamaConfig,
-    LlamaForCausalLM,
-    LlamaTokenizer,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
     HfArgumentParser,
     Seq2SeqTrainingArguments
 )
@@ -151,7 +151,7 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
 
-    tokenizer = LlamaTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
@@ -173,13 +173,13 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
-    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
 
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the llama weights are float16 type
+        torch_dtype=torch.float16, # the model weights are float16 type
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
@@ -245,7 +245,7 @@ def prepare_args(
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
     if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
+        logger.warning("We recommend enable fp16 mixed precision training.")
 
     if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
         logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
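
For context, a minimal sketch (not part of this commit) of why switching to the Auto* classes enables BLOOM: AutoConfig, AutoTokenizer, and AutoModelForCausalLM dispatch on the checkpoint's config, so the same loading path resolves to BloomForCausalLM for a BLOOM checkpoint and LlamaForCausalLM for a LLaMA one. The checkpoint name below is an illustrative assumption, not taken from the commit.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-560m"  # assumed example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, padding_side="left")
config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16  # the model weights are float16 type
)
print(type(model).__name__)  # BloomForCausalLM here; LlamaForCausalLM for a LLaMA checkpoint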