support streaming data, fix #284 #274 #268

Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
Author: hiyouga
Date: 2023-07-31 23:33:00 +08:00
Parent: 124f61b404
Commit: dd3f3e9749
28 changed files with 478 additions and 344 deletions
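The commit message advertises streaming-data support, but the excerpt below only covers the utility-module changes. As a rough, non-authoritative sketch of what streaming loading usually looks like with HuggingFace `datasets` (the function name, the JSON data files, and the "text" column are placeholders, not taken from this commit):

from datasets import load_dataset

def load_streaming_dataset(data_files, tokenizer, cutoff_len=1024):
    # streaming=True yields an IterableDataset that is read lazily,
    # so large corpora never have to fit in memory at once.
    dataset = load_dataset("json", data_files=data_files, split="train", streaming=True)

    def tokenize(examples):
        # "text" is a placeholder column name
        return tokenizer(examples["text"], truncation=True, max_length=cutoff_len)

    # IterableDataset supports map/shuffle, but shuffling draws from a
    # fixed-size buffer rather than permuting the whole dataset.
    dataset = dataset.map(tokenize, batched=True)
    return dataset.shuffle(buffer_size=16384, seed=42)

When such an iterable dataset is passed to the HF Trainer, max_steps must be set explicitly because the dataset has no length.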


@@ -1,12 +1,14 @@
 import torch
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
-from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 from llmtuner.extras.constants import LAYERNORM_NAMES
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel

 class AverageMeter:
     r"""
@@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
     return logits_processor

-def print_trainable_params(model: torch.nn.Module) -> None:
+def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
+    Returns the number of trainable parameters and number of all parameters in the model.
+    """
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
         # if using DS Zero 3 and the weights are initialized empty
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel
         # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
         if param.__class__.__name__ == "Params4bit":
             num_params = num_params * 2
         all_param += num_params
         if param.requires_grad:
             trainable_params += num_params
-    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-        trainable_params, all_param, 100 * trainable_params / all_param))
+    return trainable_params, all_param

 # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 def prepare_model_for_training(
-    model: PreTrainedModel,
+    model: "PreTrainedModel",
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> PreTrainedModel:
+) -> "PreTrainedModel":
@@ -84,6 +94,9 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

     if finetuning_type != "full" and hasattr(model, output_layer_name):
+        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
+            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
+
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
@@ -92,11 +105,8 @@ def prepare_model_for_training(
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)

-        new_output_layer = CastOutputToFloat(output_layer)
-        # adapt to LLaMA-2's pretraining_tp (actually LLaMA models can automatically do casting but BLOOM models cannot)
-        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
-        setattr(new_output_layer, "weight", output_layer.weight)
-        setattr(model, output_layer_name, new_output_layer)
+        setattr(model, output_layer_name, CastOutputToFloat(output_layer))

     return model
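A minimal usage sketch of the two helpers changed above, assuming they are exported from llmtuner.extras.misc (the file path is not visible in this excerpt) and using a placeholder model id:

from transformers import AutoModelForCausalLM
from llmtuner.extras.misc import count_parameters, prepare_model_for_training

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Casts LayerNorms to fp32, enables gradient checkpointing, and wraps the
# output layer so its activations are upcast to fp32 for the loss.
model = prepare_model_for_training(model, finetuning_type="lora")
trainable_params, all_param = count_parameters(model)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
    trainable_params, all_param, 100 * trainable_params / all_param))

Returning the counts instead of printing them (the old print_trainable_params behavior) lets callers log them wherever they like, e.g. in a training callback or a web UI.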