fix generating args

Former-commit-id: 52805a8441bd7b324bd89489de60f18f103c8e4c
This commit is contained in:
hiyouga
2023-06-13 01:33:56 +08:00
parent 4724ae3492
commit 6828f07d54
5 changed files with 20 additions and 16 deletions

View File

@@ -195,8 +195,6 @@ def load_pretrained(
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
else:
raise NotImplementedError
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -273,8 +271,8 @@ def prepare_args(
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
raise ValueError("Quantization is incompatible with the full-parameter tuning.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -358,7 +356,14 @@ def prepare_data(
)
elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
extension = dataset_attr.file_name.split(".")[-1]
if extension == "csv":
file_type = "csv"
elif extension == "json" or extension == "jsonl":
file_type = "json"
else:
file_type = "text"
if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1)
@@ -366,7 +371,7 @@ def prepare_data(
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset(
extension if extension in ["csv", "json"] else "text",
file_type,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None