Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: e82f527ea583a7e99a25a06c7fe7b03c1dc2ebb9
This commit is contained in:
@@ -6,7 +6,7 @@ from typing_extensions import Annotated
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
|
||||
from .chat import (
|
||||
create_chat_completion_response,
|
||||
create_score_evaluation_response,
|
||||
@@ -22,7 +22,7 @@ from .protocol import (
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
if is_fastapi_available():
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
@@ -3,7 +3,8 @@ import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.packages import is_fastapi_availble
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_fastapi_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
@@ -20,7 +21,7 @@ from .protocol import (
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
if is_fastapi_available():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
@@ -39,6 +41,8 @@ ROLE_MAPPING = {
|
||||
|
||||
|
||||
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from .aligner import align_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum, merge_dataset
|
||||
from .utils import merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -61,8 +61,6 @@ def load_single_dataset(
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
|
||||
checksum(data_files, dataset_attr.file_sha1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ class DatasetAttr:
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
""" extra configs """
|
||||
file_sha1: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
@@ -99,7 +98,6 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
|
||||
@@ -308,7 +308,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||
|
||||
jinja_template += (
|
||||
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
|
||||
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
|
||||
)
|
||||
|
||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets
|
||||
|
||||
@@ -26,21 +25,6 @@ class Role(str, Enum):
|
||||
OBSERVATION = "observation"
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
return
|
||||
|
||||
if len(data_files) != 1:
|
||||
logger.warning("Checksum failed: too many files.")
|
||||
return
|
||||
|
||||
with open(data_files[0], "rb") as f:
|
||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||
if sha1 != file_sha1:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
|
||||
@@ -139,13 +139,15 @@ class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
|
||||
@@ -320,14 +320,14 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||
},
|
||||
"DeepSeek-MoE-236B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
},
|
||||
"DeepSeek-MoE-236B": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
|
||||
},
|
||||
"DeepSeek-MoE-236B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
||||
@@ -424,13 +424,13 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-2B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-2b",
|
||||
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
|
||||
},
|
||||
"CodeGemma-7B": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b",
|
||||
},
|
||||
"CodeGemma-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/codegemma-7b-it",
|
||||
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
|
||||
},
|
||||
},
|
||||
@@ -581,6 +581,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
||||
},
|
||||
"LLaMA3-70B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
@@ -1174,6 +1177,30 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||
},
|
||||
"Yi-1.5-6B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
|
||||
},
|
||||
"Yi-1.5-9B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
|
||||
},
|
||||
"Yi-1.5-34B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
|
||||
},
|
||||
"Yi-1.5-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
||||
},
|
||||
"Yi-1.5-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
|
||||
},
|
||||
"Yi-1.5-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
|
||||
},
|
||||
},
|
||||
template="yi",
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ def _get_package_version(name: str) -> "Version":
|
||||
return version.parse("0.0.0")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
def is_fastapi_available():
|
||||
return _is_package_available("fastapi")
|
||||
|
||||
|
||||
|
||||
@@ -41,9 +41,9 @@ def llama_attention_forward(
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -87,7 +87,7 @@ def llama_attention_forward(
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states: "torch.Tensor" = self.q_proj(hidden_states)
|
||||
key_states: "torch.Tensor" = self.k_proj(hidden_states)
|
||||
value_states: "torch.Tensor" = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -270,11 +270,12 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, :groupsz]
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
@@ -302,7 +303,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
|
||||
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
|
||||
@@ -68,6 +68,8 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
|
||||
setattr(model.config, "torch_dtype", output_dtype)
|
||||
model = model.to(output_dtype)
|
||||
else:
|
||||
setattr(model.config, "torch_dtype", torch.float16)
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
|
||||
@@ -71,14 +71,12 @@ def create_web_demo() -> gr.Blocks:
|
||||
|
||||
|
||||
def run_web_ui() -> None:
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
||||
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
|
||||
create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
create_ui().queue().launch(share=gradio_share, server_name=server_name)
|
||||
|
||||
|
||||
def run_web_demo() -> None:
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
||||
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
|
||||
create_web_demo().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
create_web_demo().queue().launch(share=gradio_share, server_name=server_name)
|
||||
|
||||
@@ -4,10 +4,9 @@ from llmtuner.webui.interface import create_ui
|
||||
|
||||
|
||||
def main():
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
||||
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
|
||||
create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
|
||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
||||
create_ui().queue().launch(share=gradio_share, server_name=server_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user