Former-commit-id: e82f527ea583a7e99a25a06c7fe7b03c1dc2ebb9
This commit is contained in:
BUAADreamer
2024-05-13 23:28:52 +08:00
37 changed files with 181 additions and 154 deletions

View File

@@ -6,7 +6,7 @@ from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
@@ -22,7 +22,7 @@ from .protocol import (
)
if is_fastapi_availble():
if is_fastapi_available():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

View File

@@ -3,7 +3,8 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.packages import is_fastapi_availble
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
@@ -20,7 +21,7 @@ from .protocol import (
)
if is_fastapi_availble():
if is_fastapi_available():
from fastapi import HTTPException, status
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -39,6 +41,8 @@ ROLE_MAPPING = {
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

View File

@@ -11,7 +11,7 @@ from .aligner import align_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .utils import checksum, merge_dataset
from .utils import merge_dataset
if TYPE_CHECKING:
@@ -61,8 +61,6 @@ def load_single_dataset(
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.file_sha1)
else:
raise NotImplementedError

View File

@@ -21,7 +21,6 @@ class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: bool = False
@@ -99,7 +98,6 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)

View File

@@ -308,7 +308,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")

View File

@@ -1,6 +1,5 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -26,21 +25,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)

View File

@@ -139,13 +139,15 @@ class LogCallback(TrainerCallback):
r"""
Event called after an evaluation phase.
"""
self._close_thread_pool()
if not self.do_train:
self._close_thread_pool()
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
self._close_thread_pool()
if not self.do_train:
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""

View File

@@ -320,14 +320,14 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-MoE-236B": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
},
"DeepSeek-MoE-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
@@ -424,13 +424,13 @@ register_model_group(
register_model_group(
models={
"CodeGemma-2B": {
DownloadSource.DEFAULT: "google/codegemma-2b",
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Chat": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
},
@@ -581,6 +581,9 @@ register_model_group(
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
},
"LLaMA3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
},
},
template="llama3",
)
@@ -1174,6 +1177,30 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
"Yi-1.5-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
},
"Yi-1.5-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
},
"Yi-1.5-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
},
"Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
},
"Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
},
"Yi-1.5-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
},
template="yi",
)

View File

@@ -20,7 +20,7 @@ def _get_package_version(name: str) -> "Version":
return version.parse("0.0.0")
def is_fastapi_availble():
def is_fastapi_available():
return _is_package_available("fastapi")

View File

@@ -41,9 +41,9 @@ def llama_attention_forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -87,7 +87,7 @@ def llama_attention_forward(
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -270,11 +270,12 @@ def llama_sdpa_attention_forward(
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, :groupsz]
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
@@ -302,7 +303,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers==4.40.1", "To fix: pip install transformers==4.40.1")
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -68,6 +68,8 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype)
else:
setattr(model.config, "torch_dtype", torch.float16)
model.save_pretrained(
save_directory=model_args.export_dir,

View File

@@ -71,14 +71,12 @@ def create_web_demo() -> gr.Blocks:
def run_web_ui() -> None:
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name)
def run_web_demo() -> None:
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
create_web_demo().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_web_demo().queue().launch(share=gradio_share, server_name=server_name)

View File

@@ -4,10 +4,9 @@ from llmtuner.webui.interface import create_ui
def main():
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0")))
create_ui().queue().launch(share=gradio_share, server_name=server_name, server_port=server_port)
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name)
if __name__ == "__main__":