Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: ce5cb0f897eebe32a1c2c0a78fe1b0267e4b6d9d
@@ -51,7 +51,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.environ.get("API_KEY")
     security = HTTPBearer(auto_error=False)

     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
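The two forms are equivalent (`os.environ.get` already defaults to `None`); the change just drops the redundant argument. For context, a minimal sketch of how such a bearer-token dependency typically rejects requests; only the signature appears in the diff, the body below is an assumption:

```python
import os
from typing import Annotated, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

api_key = os.environ.get("API_KEY")
security = HTTPBearer(auto_error=False)  # auto_error=False: a missing header yields None instead of an immediate 403


async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
    # if API_KEY is unset, the endpoint stays open; otherwise the bearer token must match
    if api_key and (auth is None or auth.credentials != api_key):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
```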
@@ -65,12 +65,13 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)

-        do_sample = input_kwargs.pop("do_sample", None)
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
+        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)
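Moving the engine defaults into `pop` (instead of resolving them later with `x or default`) matters because `or` treats valid falsy values as missing: a caller who explicitly sends `temperature=0.0` or `top_k=0` would silently get the engine default back. A small sketch of the difference, using hypothetical values:

```python
generating_args = {"temperature": 0.7}

# old pattern: resolve later with `or` -- 0.0 is falsy, so the caller's value is lost
input_kwargs = {"temperature": 0.0}
temperature = input_kwargs.pop("temperature", None)
print(temperature or generating_args["temperature"])  # -> 0.7 (caller's value ignored)

# new pattern: the default only applies when the key is absent
input_kwargs = {"temperature": 0.0}
print(input_kwargs.pop("temperature", generating_args["temperature"]))  # -> 0.0
```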
@@ -78,14 +79,16 @@ class HuggingfaceEngine(BaseEngine):
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
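The `eos_token_id` list makes generation stop on the primary EOS token or any additional special token. A quick illustration with an arbitrary chat tokenizer (the model name is just an example, not taken from the diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")  # hypothetical choice
eos_ids = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
# generation halts on the base EOS or any additional special token (e.g. <|im_end|>)
print(tokenizer.convert_ids_to_tokens(eos_ids))
```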
@@ -94,6 +97,10 @@ class HuggingfaceEngine(BaseEngine):
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True

+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
         if max_length:
             generating_args.pop("max_new_tokens", None)
             generating_args["max_length"] = max_length
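Dropping `temperature` and `top_p` under greedy decoding is presumably meant to silence transformers' "`temperature` is set ... but `do_sample` is False" warnings; the sampling knobs have no effect there anyway. The guard in isolation:

```python
# a standalone sketch of the guard, assuming a plain dict of generation kwargs
generating_args = {"do_sample": False, "temperature": 0.7, "top_p": 0.9, "max_new_tokens": 32}

if not generating_args["do_sample"]:
    # greedy/beam search ignores sampling parameters, so remove them entirely
    generating_args.pop("temperature", None)
    generating_args.pop("top_p", None)

print(generating_args)  # -> {'do_sample': False, 'max_new_tokens': 32}
```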
@@ -89,43 +89,35 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)

-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        use_beam_search = self.generating_args["num_beams"] > 1
+        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
+        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
+        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
+        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
+        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
         stop = input_kwargs.pop("stop", None)

-        generating_args = self.generating_args.copy()
-        generating_args.update(
-            dict(
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            )
-        )
-
+        max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
         if max_length:
-            generating_args["max_new_tokens"] = max_length - prompt_length
+            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

         if max_new_tokens:
-            generating_args["max_new_tokens"] = max_new_tokens
+            max_tokens = max_new_tokens

         sampling_params = SamplingParams(
-            n=generating_args["num_return_sequences"],
-            repetition_penalty=generating_args["repetition_penalty"],
-            temperature=generating_args["temperature"],
-            top_p=generating_args["top_p"],
-            top_k=generating_args["top_k"],
-            use_beam_search=generating_args["num_beams"] > 1,
-            length_penalty=generating_args["length_penalty"],
+            n=num_return_sequences,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            use_beam_search=use_beam_search,
+            length_penalty=length_penalty,
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            max_tokens=generating_args["max_new_tokens"],
+            max_tokens=max_tokens,
             skip_special_tokens=True,
         )
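The vLLM engine now resolves a single `max_tokens` budget instead of mutating `generating_args`. The resolution order, isolated as a hypothetical helper (the function name and parameters are mine; the logic follows the diff):

```python
def resolve_max_tokens(prompt_length, max_length, max_new_tokens, default_new_tokens, default_length):
    # engine defaults first, then the request's max_length, then its max_new_tokens
    max_tokens = default_new_tokens or default_length
    if max_length:
        # max_length budgets prompt + completion together; clamp to at least 1 token
        max_tokens = max_length - prompt_length if max_length > prompt_length else 1
    if max_new_tokens:
        max_tokens = max_new_tokens
    return max_tokens


print(resolve_max_tokens(100, 128, None, 512, None))  # -> 28
print(resolve_max_tokens(100, 64, None, 512, None))   # -> 1 (prompt longer than the budget)
print(resolve_max_tokens(100, 128, 256, 512, None))   # -> 256 (max_new_tokens wins)
```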
@@ -53,7 +53,7 @@ class LogCallback(TrainerCallback):
         self.aborted = False
         self.do_train = False
         """ Web UI """
-        self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
+        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = LoggerHandler(output_dir)
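The same environment-flag change recurs below in `check_dependencies` (`DISABLE_VERSION_CHECK`) and in the Gradio launchers (`GRADIO_SHARE`). The point is robustness: `bool(int(...))` crashes on any non-numeric string, while the new form accepts "1"/"true" in any casing and treats everything else as off. A sketch:

```python
import os

os.environ["LLAMABOARD_ENABLED"] = "true"

# old form: only digit strings parse; "true" raises ValueError
try:
    enabled = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
except ValueError as exc:
    print("old form fails:", exc)

# new form: "1" and "true" (case-insensitive) enable, anything else disables
enabled = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
print("new form:", enabled)  # -> True
```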
@@ -58,7 +58,7 @@ class AverageMeter:


 def check_dependencies() -> None:
-    if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
+    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
         require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
@@ -21,6 +21,9 @@ def smooth(scalars: List[float]) -> List[float]:
     r"""
     EMA implementation according to TensorBoard.
     """
+    if len(scalars) == 0:
+        return []
+
     last = scalars[0]
     smoothed = []
     weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
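The new guard returns early on an empty series, which previously raised `IndexError` at `scalars[0]`. For completeness, a runnable sketch of the whole function; the loop body after `weight` is assumed from TensorBoard's usual EMA formulation, since the diff cuts off there:

```python
import math
from typing import List


def smooth(scalars: List[float]) -> List[float]:
    if len(scalars) == 0:
        return []  # the new guard: no more IndexError on empty logs

    last = scalars[0]
    smoothed = []
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val  # assumed EMA step
        smoothed.append(smoothed_val)
        last = smoothed_val

    return smoothed


print(smooth([]))          # -> []
print(smooth([1.0, 2.0]))  # short series get a tiny weight, so barely any smoothing
```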
@@ -32,6 +35,9 @@ def smooth(scalars: List[float]) -> List[float]:


 def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
+    r"""
+    Plots loss curves in LlamaBoard.
+    """
     plt.close("all")
     plt.switch_backend("agg")
     fig = plt.figure()
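`plt.close("all")` matters in a long-lived LlamaBoard process: each refresh builds a new figure, and matplotlib keeps every open figure alive until told otherwise. A sketch of the effect:

```python
import matplotlib

matplotlib.use("agg")  # headless backend, same effect as plt.switch_backend("agg")
import matplotlib.pyplot as plt


def gen_figure():
    plt.close("all")  # drop figures from earlier calls so repeated refreshes don't accumulate
    return plt.figure()


for _ in range(3):
    gen_figure()

print(plt.get_fignums())  # -> [1]; without the close() call this would be [1, 2, 3]
```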
@@ -51,6 +57,9 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":


 def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+    r"""
+    Plots loss curves and saves the image.
+    """
     plt.switch_backend("agg")
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
         data = json.load(f)
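`plot_loss` reads the Trainer's saved state rather than a live log. A sketch of the file it parses, with a hypothetical output directory (`TRAINER_STATE_NAME` is the conventional `"trainer_state.json"`):

```python
import json
import os

TRAINER_STATE_NAME = "trainer_state.json"  # value of the constant used in the diff

save_dictionary = "saves/my-run"  # hypothetical checkpoint directory
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
    data = json.load(f)

# log_history holds dicts like {"step": 10, "loss": 1.23, ...}; pick the key to plot
steps = [log["step"] for log in data["log_history"] if "loss" in log]
losses = [log["loss"] for log in data["log_history"] if "loss" in log]
print(list(zip(steps, losses))[:3])
```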
@@ -1,9 +1,10 @@
+import os
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict

 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
@@ -44,6 +45,10 @@ def patch_config(
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

+    if is_torch_npu_available():
+        use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
+        torch.npu.set_compile_mode(jit_compile=use_jit_compile)
+
     configure_attn_implementation(config, model_args)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
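`torch.npu` is not part of stock PyTorch; it appears only when Huawei's Ascend `torch_npu` extension is installed, which is what `is_torch_npu_available()` detects. A guarded sketch that is safe to run on any machine:

```python
import os

import torch
from transformers import is_torch_npu_available

if is_torch_npu_available():
    # torch.npu exists only with the Ascend torch_npu extension installed;
    # JIT compilation defaults to off and is opted into via JIT_COMPILE=1 (or "true")
    use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
    torch.npu.set_compile_mode(jit_compile=use_jit_compile)
else:
    print("no Ascend NPU detected; JIT_COMPILE has no effect")
```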
@@ -57,7 +62,7 @@ def patch_config(
         logger.info("Using KV cache for faster generation.")

     if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "use_flash_attn", model_args.flash_attn)
+        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
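The comparison changed because `model_args.flash_attn` is no longer a boolean but a string choice (this diff shows at least "sdpa" and "fa2"), and Qwen's legacy `use_flash_attn` config flag should only be set when FlashAttention-2 is actually requested. In miniature:

```python
# hypothetical values of model_args.flash_attn after the switch to string choices
for flash_attn in ("fa2", "sdpa", "off"):
    use_flash_attn = flash_attn == "fa2"  # the old code would have stored the truthy string itself
    print(flash_attn, "->", use_flash_attn)
# fa2 -> True, everything else -> False
```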
@@ -22,7 +22,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:

     elif model_args.flash_attn == "sdpa":
         if not is_sdpa_available():
-            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
+            logger.warning("torch>=2.1.1 is required for SDPA attention.")
             return

         requested_attn_implementation = "sdpa"
@@ -52,4 +52,4 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
     elif attn_implementation == "sdpa":
         logger.info("Using torch SDPA for faster training and inference.")
     else:
-        logger.info("Using vanilla Attention implementation.")
+        logger.info("Using vanilla attention implementation.")
@@ -71,12 +71,12 @@ def create_web_demo() -> gr.Blocks:


 def run_web_ui() -> None:
-    gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
+    gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
     server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
     create_ui().queue().launch(share=gradio_share, server_name=server_name)


 def run_web_demo() -> None:
-    gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
+    gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
     server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
     create_web_demo().queue().launch(share=gradio_share, server_name=server_name)