Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
src/api.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 import uvicorn
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Level: api, webui > chat, eval, train > data, model > hparams > extras
 
 from .cli import VERSION
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from contextlib import asynccontextmanager
 from typing import Optional
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import base64
 import io
 import json
@@ -78,9 +92,11 @@ def _process_request(
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 
         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
-            name = message.tool_calls[0].function.name
-            arguments = message.tool_calls[0].function.arguments
-            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+            tool_calls = [
+                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+                for tool_call in message.tool_calls
+            ]
+            content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         elif isinstance(message.content, list):
            for input_item in message.content:
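Note: the old code kept only the first tool call (and serialized it under the singular key "argument"); the new code serializes every call as a JSON list, so parallel function calls survive the round-trip into the function-role message. A minimal sketch of the resulting content string, outside the commit and with made-up call data:

    import json

    # Hypothetical OpenAI-style tool calls attached to one assistant message.
    tool_calls = [
        {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"},
        {"name": "get_time", "arguments": "{\"tz\": \"UTC\"}"},
    ]

    # The new code serializes the whole list instead of just tool_calls[0].
    content = json.dumps(tool_calls, ensure_ascii=False)
    print(content)  # [{"name": "get_weather", ...}, {"name": "get_time", ...}]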
@@ -104,7 +120,7 @@ def _process_request(
     if isinstance(tool_list, list) and len(tool_list):
         try:
             tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
-        except Exception:
+        except json.JSONDecodeError:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
         tools = None
@@ -146,15 +162,17 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(response.response_text)
+            result = chat_model.engine.template.extract_tool(response.response_text)
         else:
             result = response.response_text
 
-        if isinstance(result, tuple):
-            name, arguments = result
-            function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                function = Function(name=tool[0], arguments=tool[1])
+                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
         else:
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
 
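Note: extract_tool now returns either a plain string (no tool call found) or a list of (name, arguments) tuples, and the response builder wraps each tuple in its own FunctionCall. A rough, self-contained rendition of that contract, with a stand-in extractor and made-up IDs:

    import json
    import uuid

    # Stand-in for template.extract_tool: plain text comes back unchanged,
    # recognized calls come back as a list of (name, json_arguments) tuples.
    def extract_tool(text: str):
        try:
            calls = json.loads(text)
            return [(c["name"], json.dumps(c["arguments"], ensure_ascii=False)) for c in calls]
        except (json.JSONDecodeError, TypeError, KeyError):
            return text

    result = extract_tool('[{"name": "search", "arguments": {"q": "llama"}}]')
    if isinstance(result, list):  # same branch as the hunk above
        tool_calls = [
            {"id": "call_{}".format(uuid.uuid4().hex), "name": name, "arguments": arguments}
            for name, arguments in result
        ]
        print(tool_calls)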
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 from typing import TYPE_CHECKING, Any, Dict
 
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from enum import Enum, unique
 from typing import Any, Dict, List, Optional, Union
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .base_engine import BaseEngine
 from .chat_model import ChatModel
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
         generating_args: "GeneratingArguments",
     ) -> None: ...
 
-    @abstractmethod
-    async def start(
-        self,
-    ) -> None: ...
-
     @abstractmethod
     async def chat(
         self,
@@ -1,3 +1,20 @@
+# Copyright 2024 THUDM and the LlamaFactory team.
+#
+# This code is inspired by the THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import asyncio
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
     from .base_engine import BaseEngine, Response
 
 
-def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
     asyncio.set_event_loop(loop)
     loop.run_forever()
 
@@ -32,7 +49,6 @@ class ChatModel:
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
-        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
 
     def chat(
         self,
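Note: with engine.start() removed, the constructor no longer schedules a startup coroutine; the daemon thread simply keeps a private event loop alive so the synchronous chat() wrapper can hand coroutines to it. The pattern reduced to its core (a sketch, not the commit's code):

    import asyncio
    from threading import Thread

    def _run(loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    loop = asyncio.new_event_loop()
    Thread(target=_run, args=(loop,), daemon=True).start()

    async def generate() -> str:  # stand-in for an engine coroutine
        await asyncio.sleep(0.1)
        return "hello"

    # Synchronous callers block on a concurrent.futures.Future while the
    # coroutine itself runs on the dedicated loop thread.
    print(asyncio.run_coroutine_threadsafe(generate(), loop).result())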
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import asyncio
 import concurrent.futures
 import os
@@ -40,11 +54,19 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
 
     @staticmethod
     def _process_args(
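Note: moving the semaphore into __init__ (as self.semaphore) makes the concurrency cap exist as soon as the engine is built, instead of only after a separate start() call. A self-contained sketch of how MAX_CONCURRENT bounds in-flight requests:

    import asyncio
    import os

    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

    async def bounded_chat(i: int) -> None:
        async with semaphore:  # at most MAX_CONCURRENT bodies run at once
            await asyncio.sleep(0.1)  # stand-in for a generation call
            print("finished request", i)

    async def main() -> None:
        await asyncio.gather(*(bounded_chat(i) for i in range(4)))

    asyncio.run(main())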
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
 
         return scores
 
-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
 
@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:
@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
 
         loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
 
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
-    from vllm.sequence import MultiModalData
+
+    if is_vllm_version_greater_than_0_5():
+        from vllm.multimodal.image import ImagePixelData
+    else:
+        from vllm.sequence import MultiModalData
 
 
 if TYPE_CHECKING:
@@ -41,14 +59,14 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.generating_args = generating_args.to_dict()
 
         engine_args = {
             "model": model_args.model_name_or_path,
             "trust_remote_code": True,
             "download_dir": model_args.cache_dir,
-            "dtype": model_args.vllm_dtype,
+            "dtype": model_args.infer_dtype,
             "max_model_len": model_args.vllm_maxlen,
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
+                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None
 
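Note: is_vllm_version_greater_than_0_5 gates both the import and the request construction, because vLLM moved image inputs from vllm.sequence.MultiModalData to vllm.multimodal.image.ImagePixelData around 0.5.0. The commit's own helper lives in extras.packages and is not shown here; a plausible shape for such a check, assuming the packaging library:

    import importlib.metadata
    from functools import lru_cache

    from packaging import version

    @lru_cache
    def is_vllm_version_greater_than_0_5() -> bool:
        # Compare the installed distribution's version once and cache the answer.
        try:
            return version.parse(importlib.metadata.version("vllm")) >= version.parse("0.5.0")
        except importlib.metadata.PackageNotFoundError:
            return False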
@@ -162,9 +183,6 @@
         )
         return result_generator
 
-    async def start(self) -> None:
-        pass
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import random
 import subprocess
@@ -60,7 +74,7 @@ class Command(str, Enum):
 
 
 def main():
-    command = sys.argv.pop(1)
+    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
     if command == Command.API:
         run_api()
     elif command == Command.CHAT:
@@ -77,7 +91,7 @@ def main():
             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
-            subprocess.run(
+            process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
@@ -92,6 +106,7 @@ def main():
                 ),
                 shell=True,
             )
+            sys.exit(process.returncode)
         else:
             run_exp()
     elif command == Command.WEBDEMO:
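Note: capturing the CompletedProcess and calling sys.exit(process.returncode) makes the launcher transparent: when torchrun fails, the wrapping CLI process now exits with the same status instead of reporting success. In miniature (a sketch with a trivial POSIX shell command standing in for torchrun):

    import subprocess
    import sys

    process = subprocess.run("exit 3", shell=True)
    sys.exit(process.returncode)  # the wrapper reports the child's status (3 here)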
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -10,6 +24,7 @@ from .data_utils import Role
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +190,10 @@ def convert_sharegpt(
 
 
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +226,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
 
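Note: the extra `or (training_args.local_process_index != 0)` term forces every non-main rank to read from the datasets cache unconditionally; combined with main_process_first(), only rank 0 recomputes the mapping even when overwrite_cache is set. The predicate in isolation (a sketch, not the project's code):

    overwrite_cache = True

    def load_from_cache_file(local_process_index: int) -> bool:
        # Rank 0 may recompute; every other rank must reuse rank 0's cache.
        return (not overwrite_cache) or (local_process_index != 0)

    assert load_from_cache_file(0) is False  # main process re-preprocesses
    assert load_from_cache_file(1) is True   # workers wait and read the cache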
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Any, Dict, Sequence
 
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
 
 from datasets import concatenate_datasets, interleave_datasets
 
@@ -16,6 +30,9 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
 @unique
 class Role(str, Enum):
     USER = "user"
@@ -25,13 +42,6 @@ class Role(str, Enum):
     OBSERVATION = "observation"
 
 
-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
-    max_target_len = int(max_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = max_len - min(max_target_len, target_len)
-    return max_source_len, max_target_len
-
-
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",
@@ -1,83 +1,36 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
-
-
-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-JSON_FORMAT_PROMPT = (
-    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
-)
-
-
-TOOL_SYSTEM_PROMPT = (
-    "You have access to the following tools:\n{tool_text}"
-    "Use the following format if using a tool:\n"
-    "```\n"
-    "Action: tool name (one of [{tool_names}]).\n"
-    "Action Input: the input to the tool{format_prompt}.\n"
-    "```\n"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    tool_names = []
-    for tool in tools:
-        param_text = ""
-        for name, param in tool["parameters"]["properties"].items():
-            required = ", required" if name in tool["parameters"].get("required", []) else ""
-            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
-            items = (
-                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
-            )
-            param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
-                name=name,
-                type=param.get("type", ""),
-                required=required,
-                desc=param.get("description", ""),
-                enum=enum,
-                items=items,
-            )
-
-        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
-            name=tool["name"], desc=tool.get("description", ""), args=param_text
-        )
-        tool_names.append(tool["name"])
-
-    return TOOL_SYSTEM_PROMPT.format(
-        tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
-    )
-
-
-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
-    action_match = re.search(regex, content)
-    if not action_match:
-        return content
-
-    tool_name = action_match.group(1).strip()
-    tool_input = action_match.group(2).strip().strip('"').strip("```")
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-
-    return tool_name, json.dumps(arguments, ensure_ascii=False)
+from typing import List, Literal, Optional, Tuple, Union
+
+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils
 
 
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default"]] = None
+    tool_format: Optional[Literal["default", "glm4"]] = None
 
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...
 
-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         raise NotImplementedError
 
@@ -128,34 +81,37 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        has_name, has_args = False, False
-        for slot in filter(lambda s: isinstance(s, str), self.slots):
-            if "{{name}}" in slot:
-                has_name = True
-            if "{{arguments}}" in slot:
-                has_args = True
-
-        if not has_name or not has_args:
-            raise ValueError("Name and arguments placeholders are required in the function formatter.")
+        if self.tool_format == "default":
+            self.slots = DefaultToolUtils.get_function_slots() + self.slots
+        elif self.tool_format == "glm4":
+            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
 
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
+        functions: List[Tuple[str, str]] = []
         try:
-            function = json.loads(content)
-            name = function["name"]
-            arguments = json.dumps(function["arguments"], ensure_ascii=False)
-        except Exception:
-            name, arguments = "", ""
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError:
+            functions = []
 
         elements = []
-        for slot in self.slots:
-            if isinstance(slot, str):
-                slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                elements.append(slot)
-            elif isinstance(slot, (dict, set)):
-                elements.append(slot)
-            else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        for name, arguments in functions:
+            for slot in self.slots:
+                if isinstance(slot, str):
+                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+                    elements.append(slot)
+                elif isinstance(slot, (dict, set)):
+                    elements.append(slot)
+                else:
+                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
 
         return elements
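Note: the rewritten apply() accepts either a single tool-call object or a JSON list and renders the slots once per call, which is what enables parallel function calling in training data. A stand-alone rendition of the core loop, with a hypothetical slot string:

    import json
    from typing import List, Tuple

    content = json.dumps([
        {"name": "get_weather", "arguments": {"city": "Berlin"}},
        {"name": "get_time", "arguments": {"tz": "UTC"}},
    ])

    tool_calls = json.loads(content)
    if not isinstance(tool_calls, list):  # single call -> treat as a list of one
        tool_calls = [tool_calls]

    functions: List[Tuple[str, str]] = [
        (c["name"], json.dumps(c["arguments"], ensure_ascii=False)) for c in tool_calls
    ]

    slots = ["Action: {{name}}\nAction Input: {{arguments}}\n"]  # hypothetical slot
    elements = [
        slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
        for name, arguments in functions
        for slot in slots
    ]
    print(elements)  # one rendered block per tool call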
@@ -163,25 +119,22 @@ class FunctionFormatter(Formatter):
 @dataclass
 class ToolFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format is None:
-            raise ValueError("Tool format was not found.")
+        if self.tool_format == "default":
+            self._tool_formatter = DefaultToolUtils.tool_formatter
+            self._tool_extractor = DefaultToolUtils.tool_extractor
+        elif self.tool_format == "glm4":
+            self._tool_formatter = GLM4ToolUtils.tool_formatter
+            self._tool_extractor = GLM4ToolUtils.tool_extractor
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
 
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
             tools = json.loads(content)
-            if not len(tools):
-                return [""]
-
-            if self.tool_format == "default":
-                return [default_tool_formatter(tools)]
-            else:
-                raise NotImplementedError
-        except Exception:
+            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+        except json.JSONDecodeError:
             return [""]
 
-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
-        if self.tool_format == "default":
-            return default_tool_extractor(content)
-        else:
-            raise NotImplementedError
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+        return self._tool_extractor(content)
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 import os
 import sys
@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
@@ -32,6 +45,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +137,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))
 
-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
 def get_dataset(
@@ -134,7 +148,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
@@ -157,7 +171,8 @@ def get_dataset(
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")
 
-        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
 
     dataset = merge_dataset(all_datasets, data_args, training_args)
 
     with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +184,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
 
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os
 from dataclasses import dataclass
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
 
@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments
     from .template import Template
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
@@ -42,12 +55,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]
 
-    prompt_ids, response_ids = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, kl_response_ids = template.encode_oneturn(
-        tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+    _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
 
     if template.efficient_eos:
         response_ids += [tokenizer.eos_token_id]
@@ -57,6 +66,12 @@
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 
+    # do not consider the kl_response
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+    prompt_ids = prompt_ids[:source_len]
+    response_ids = response_ids[:target_len]
+    kl_response_ids = kl_response_ids[:target_len]
+
     input_ids = prompt_ids + response_ids
     labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
     kl_input_ids = prompt_ids + kl_response_ids
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
@@ -31,12 +44,8 @@ def _encode_pairwise_example(
 
     chosen_messages = prompt + [response[0]]
     rejected_messages = prompt + [response[1]]
-    prompt_ids, chosen_ids = template.encode_oneturn(
-        tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, rejected_ids = template.encode_oneturn(
-        tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
 
     if template.efficient_eos:
         chosen_ids += [tokenizer.eos_token_id]
@@ -46,6 +55,13 @@
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 
+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )  # consider the response is more important
+    prompt_ids = prompt_ids[:source_len]
+    chosen_ids = chosen_ids[:target_len]
+    rejected_ids = rejected_ids[:target_len]
+
     chosen_input_ids = prompt_ids + chosen_ids
     chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
@@ -1,9 +1,26 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict, List
 
 
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer
 
     from ...hparams import DataArguments
 
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
-    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
 
     if not data_args.packing:
         if data_args.template == "gemma":
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple
 
 from ...extras.packages import is_pillow_available
 
@@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
     """
     image_seq_length = getattr(processor, "image_seq_length")
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    new_source_len = max(cutoff_len - new_target_len, 0)
+    return new_source_len, new_target_len
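Note: infer_seqlen replaces data_utils.infer_max_len (deleted earlier in this commit) and splits the cutoff between source and target: a short target keeps itself and yields the rest to the source, a short source does the reverse, and otherwise both sides are scaled proportionally. A few worked cases, with the function copied from the hunk above so the block runs on its own:

    from typing import Tuple

    def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
        # copied from the hunk above
        if target_len * 2 < cutoff_len:
            max_target_len = cutoff_len
        elif source_len * 2 < cutoff_len:
            max_target_len = cutoff_len - source_len
        else:
            max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
        new_target_len = min(max_target_len, target_len)
        new_source_len = max(cutoff_len - new_target_len, 0)
        return new_source_len, new_target_len

    # target 30 is short (30 * 2 < 64): kept whole, source gets 64 - 30 = 34.
    assert infer_seqlen(100, 30, 64) == (34, 30)
    # source 20 is short: target clipped to 64 - 20 = 44.
    assert infer_seqlen(20, 100, 64) == (20, 44)
    # both long: proportional split, int(64 * 100 / 200) = 32 each.
    assert infer_seqlen(100, 100, 64) == (32, 32)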
@@ -1,14 +1,27 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
@@ -38,10 +51,17 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if total_length >= data_args.cutoff_len:
+            break
+
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_ids = source_ids[:source_len]
+        target_ids = target_ids[:target_len]
+        total_length += source_len + target_len
+
         if data_args.train_on_prompt:
             source_mask = source_ids
         elif turn_idx != 0 and template.efficient_eos:
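Note: truncation now happens here, per turn, against a running budget (cutoff_len - total_length) instead of inside the template's _make_pairs. Early turns consume the budget first and later turns are dropped once it runs out; a compressed sketch of that behavior with toy ids and a simplified stand-in for infer_seqlen:

    cutoff_len = 32
    encoded_pairs = [([1] * 10, [2] * 10), ([3] * 10, [4] * 10), ([5] * 10, [6] * 10)]

    total_length = 0
    kept = []
    for source_ids, target_ids in encoded_pairs:
        if total_length >= cutoff_len:
            break  # budget exhausted: remaining turns are dropped entirely
        budget = cutoff_len - total_length
        source_ids = source_ids[:budget // 2]           # simplified allocation
        target_ids = target_ids[:budget - len(source_ids)]
        total_length += len(source_ids) + len(target_ids)
        kept.append((source_ids, target_ids))

    print([len(s) + len(t) for s, t in kept])  # [20, 12] -- within the 32 budget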
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
@@ -34,9 +47,7 @@ def _encode_unsupervised_example(
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
 
-    input_ids, labels = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]
 
@@ -44,6 +55,9 @@ def _encode_unsupervised_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
 
+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids = input_ids[:source_len]
+    labels = labels[:target_len]
     return input_ids, labels
 
 
@@ -1,8 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 
@@ -24,69 +38,74 @@ class Template:
     format_observation: "Formatter"
     format_tools: "Formatter"
     format_separator: "Formatter"
+    format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
     image_token: str
     efficient_eos: bool
     replace_eos: bool
-    force_system: bool
 
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
-        for query_ids, resp_ids in encoded_pairs[:-1]:
-            prompt_ids += query_ids + resp_ids
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        answer_ids = encoded_pairs[-1][1]
+        for encoded_ids in encoded_messages[:-1]:
+            prompt_ids += encoded_ids
+
+        answer_ids = encoded_messages[-1]
         return prompt_ids, answer_ids
 
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+    def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+        r"""
+        Extracts tool message.
+        """
+        return self.format_tools.extract(content)
 
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query resp
-        Turn t: sep + query resp
+        Turn 0: prefix + system + query resp
+        Turn t: sep + query resp
         """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []
-            if i == 0 and (system or tools or self.force_system):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                elements += self.format_system.apply(content=(system + tool_text))
-            elif i > 0 and i % 2 == 0:
+
+            if i == 0:
+                elements += self.format_prefix.apply()
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    elements += self.format_system.apply(content=(system + tool_text))
+
+            if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()
 
             if message["role"] == Role.USER.value:
@@ -102,11 +121,9 @@ class Template:
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
return encoded_messages
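
A minimal usage sketch of the refactored encoding API, assuming a registered template object and a Hugging Face tokenizer are at hand (the messages below are made up). The truncation arguments are gone from both methods, since `_encode` now returns one token-id list per message and truncation happens elsewhere:

# Illustrative only: exercises the new signatures shown above.
messages = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)  # last turn split off
pairs = template.encode_multiturn(tokenizer, messages)  # [(prompt_ids, answer_ids)] per turn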

    def _convert_elements_to_ids(
        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
    ) -> List[int]:
    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
        r"""
        Converts elements to token ids.
        """
@@ -127,57 +144,34 @@ class Template:

        return token_ids

    def _make_pairs(
        self,
        encoded_messages: Sequence[List[int]],
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        encoded_pairs = []
        total_length = 0
        for i in range(0, len(encoded_messages), 2):
            if total_length >= cutoff_len:
                break

            max_source_len, max_target_len = infer_max_len(
                source_len=len(encoded_messages[i]),
                target_len=len(encoded_messages[i + 1]),
                max_len=(cutoff_len - total_length),
                reserved_label_len=reserved_label_len,
            )
            source_ids = encoded_messages[i][:max_source_len]
            target_ids = encoded_messages[i + 1][:max_target_len]
            total_length += len(source_ids) + len(target_ids)
            encoded_pairs.append((source_ids, target_ids))

        return encoded_pairs


@dataclass
class Llama2Template(Template):
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        messages: Sequence[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
    ) -> List[List[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query resp
        Turn t: sep + query resp
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            system_text = ""
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                system_text = self.format_system.apply(content=(system + tool_text))[0]
            elif i > 0 and i % 2 == 0:
            if i == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    system_text = self.format_system.apply(content=(system + tool_text))[0]

            if i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
@@ -193,7 +187,7 @@ class Llama2Template(Template):

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
        return encoded_messages


TEMPLATES: Dict[str, Template] = {}
@@ -208,12 +202,12 @@ def _register_template(
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_separator: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: List[str] = [],
    stop_words: Sequence[str] = [],
    image_token: str = "<image>",
    efficient_eos: bool = False,
    replace_eos: bool = False,
    force_system: bool = False,
) -> None:
    r"""
    Registers a chat template.
@@ -245,9 +239,10 @@ def _register_template(
    template_class = Llama2Template if name.startswith("llama2") else Template
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
    default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
    default_tool_formatter = ToolFormatter(tool_format="default")
    default_separator_formatter = EmptyFormatter()
    default_prefix_formatter = EmptyFormatter()
    TEMPLATES[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
@@ -256,12 +251,12 @@ def _register_template(
        format_observation=format_observation or format_user or default_user_formatter,
        format_tools=format_tools or default_tool_formatter,
        format_separator=format_separator or default_separator_formatter,
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words,
        image_token=image_token,
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        force_system=force_system,
    )
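
A hedged sketch of registering a template under the new scheme: the template name and slots below are invented, but the point the diff makes is that an explicit `format_prefix` formatter now emits the BOS token where `force_system=True` used to be required:

# Hypothetical registration; "my_custom" is not a real template name.
_register_template(
    name="my_custom",
    format_user=StringFormatter(slots=["User: {{content}}\nAssistant:"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),  # replaces force_system=True
)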


@@ -307,6 +302,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    jinja_template = ""

    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
    if prefix:
        jinja_template += "{{ " + prefix + " }}"

    if template.default_system:
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

@@ -315,11 +314,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
    )

    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
    if isinstance(template, Llama2Template):
        pass
    elif template.force_system:
        jinja_template += "{{ " + system_message + " }}"
    else:
    if not isinstance(template, Llama2Template):
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

    jinja_template += "{% for message in messages %}"
@@ -346,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
    tool_format: Optional[str] = None,
) -> Template:
    if name is None:
        template = TEMPLATES["empty"]  # placeholder
@@ -354,6 +350,12 @@ def get_template_and_fix_tokenizer(
    if template is None:
        raise ValueError("Template {} does not exist.".format(name))

    if tool_format is not None:
        logger.info("Using tool format: {}.".format(tool_format))
        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
        template.format_tools = ToolFormatter(tool_format=tool_format)
        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
@@ -435,9 +437,8 @@ _register_template(
_register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


@@ -450,11 +451,7 @@ _register_template(
_register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    default_system=(
        "You are a helpful AI assistant built by MediaTek Research. "
        "The user you are helping speaks Traditional Chinese and comes from Taiwan."
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
)

@@ -462,10 +459,9 @@ _register_template(
_register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    efficient_eos=True,
    force_system=True,
)


@@ -473,32 +469,13 @@ _register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    force_system=True,
)


_register_template(
    name="chatglm3_system",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(
        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
    ),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    default_system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)
@@ -529,8 +506,7 @@ _register_template(

_register_template(
    name="codegeex2",
    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)


@@ -544,21 +520,15 @@ _register_template(
            )
        ]
    ),
    format_system=StringFormatter(
        slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
    ),
    default_system=(
        "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
        "by providing thorough responses. You are trained by Cohere."
    ),
    format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


@@ -591,30 +561,28 @@ _register_template(
_register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
        "developed by Deepseek Company, and you only answer questions related to computer science. "
        "For politically sensitive questions, security and privacy issues, "
        "and other non-computer science questions, you will refuse to answer\n"
    ),
    stop_words=["<|EOT|>"],
    efficient_eos=True,
)


_register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
    format_system=StringFormatter(slots=["{{content}}\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
)
@@ -622,11 +590,7 @@ _register_template(

_register_template(
    name="empty",
    format_user=StringFormatter(slots=["{{content}}"]),
    format_assistant=StringFormatter(slots=["{{content}}"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    efficient_eos=True,
    force_system=True,
)


@@ -648,13 +612,12 @@ _register_template(
_register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
    force_system=True,
)


@@ -662,36 +625,33 @@ _register_template(
    name="glm4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    force_system=True,
)


_register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
    format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
    format_separator=EmptyFormatter(slots=["<eoa>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<eoa>"],
    efficient_eos=True,
    efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
)


_register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed "
        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
        "by the user such as English and 中文."
    ),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
)
@@ -700,7 +660,6 @@ _register_template(

_register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
)

@@ -723,9 +682,7 @@ _register_template(
            )
        ]
    ),
    format_system=StringFormatter(
        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    ),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_observation=StringFormatter(
        slots=[
            (
@@ -734,7 +691,7 @@ _register_template(
            )
        ]
    ),
    default_system="You are a helpful assistant.",
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
)
@@ -743,24 +700,21 @@ _register_template(

_register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)


_register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


@@ -774,27 +728,25 @@ _register_template(
            )
        ]
    ),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
    force_system=True,
)


_register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    force_system=True,
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="phi",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are a helpful AI assistant.",
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|end|>"],
    replace_eos=True,
)
@@ -827,7 +779,6 @@ _register_template(
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
    force_system=True,
)

140
src/llamafactory/data/tool_utils.py
Normal file
@@ -0,0 +1,140 @@

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

from .data_utils import SLOTS


DEFAULT_TOOL_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}]).\n"
    "Action Input: the input to the tool, in a JSON format representing the kwargs "
    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
    "```\n"
)


GLM4_TOOL_PROMPT = (
    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)


@dataclass
class ToolUtils(ABC):
    @staticmethod
    @abstractmethod
    def get_function_slots() -> SLOTS: ...

    @staticmethod
    @abstractmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...

    @staticmethod
    @abstractmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...


class DefaultToolUtils(ToolUtils):
    @staticmethod
    def get_function_slots() -> SLOTS:
        return ["Action: {{name}}\nAction Input: {{arguments}}\n"]

    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        tool_text = ""
        tool_names = []
        for tool in tools:
            param_text = ""
            for name, param in tool["parameters"]["properties"].items():
                required, enum, items = "", "", ""
                if name in tool["parameters"].get("required", []):
                    required = ", required"

                if param.get("enum", None):
                    enum = ", should be one of [{}]".format(", ".join(param["enum"]))

                if param.get("items", None):
                    items = ", where each item should be {}".format(param["items"].get("type", ""))

                param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                    name=name,
                    type=param.get("type", ""),
                    required=required,
                    desc=param.get("description", ""),
                    enum=enum,
                    items=items,
                )

            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
                name=tool["name"], desc=tool.get("description", ""), args=param_text
            )
            tool_names.append(tool["name"])

        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

    @staticmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
        action_match: List[Tuple[str, str]] = re.findall(regex, content)
        if not action_match:
            return content

        results = []
        for match in action_match:
            tool_name = match[0].strip()
            tool_input = match[1].strip().strip('"').strip("```")
            try:
                arguments = json.loads(tool_input)
                results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
            except json.JSONDecodeError:
                return content

        return results
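
A quick sketch of the extractor in action; the tool names and arguments are invented, and the module path assumes the new file added above:

# Illustrative only: parse an "Action / Action Input" style model response.
from llamafactory.data.tool_utils import DefaultToolUtils

response = (
    'Action: get_weather\n'
    'Action Input: {"location": "Berlin"}\n'
    'Action: get_time\n'
    'Action Input: {"timezone": "Europe/Berlin"}'
)
print(DefaultToolUtils.tool_extractor(response))
# [('get_weather', '{"location": "Berlin"}'), ('get_time', '{"timezone": "Europe/Berlin"}')]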


class GLM4ToolUtils(ToolUtils):
    @staticmethod
    def get_function_slots() -> SLOTS:
        return ["{{name}}\n{{arguments}}"]

    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
            )

        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

    @staticmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
        if "\n" not in content:
            return content

        tool_name, tool_input = content.split("\n", maxsplit=1)
        try:
            arguments = json.loads(tool_input)
        except json.JSONDecodeError:
            return content

        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
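
And the GLM-4 counterpart, which expects the function name on the first line and JSON arguments after it (again with invented values):

# Illustrative only: GLM-4 style "name\n{json args}" responses.
from llamafactory.data.tool_utils import GLM4ToolUtils

print(GLM4ToolUtils.tool_extractor('get_weather\n{"location": "Berlin"}'))
# [('get_weather', '{"location": "Berlin"}')]
print(GLM4ToolUtils.tool_extractor("just a plain answer"))  # no "\n" -> returned as-is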

@@ -1,4 +1,41 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by Dan's test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import inspect
import json
@@ -26,9 +63,7 @@ class Evaluator:
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
        self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = [
            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
        ]
        self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
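
Why the `[-1]` still works without the old `prefix`: for common tokenizers a bare choice letter encodes to a single token, so the last id is the letter's id. A hedged sanity check (GPT-2 is only a stand-in model here):

# Illustrative only; any tokenizer whose capital letters are single tokens behaves the same.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
for ch in ["A", "B", "C", "D"]:
    ids = tok.encode(ch, add_special_tokens=False)
    print(ch, ids)  # e.g. A -> [32]; choice_inputs keeps the last (only) id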

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

@@ -10,7 +24,6 @@ class EvalTemplate:
    system: str
    choice: str
    answer: str
    prefix: str

    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        r"""
@@ -42,8 +55,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}


def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)


def get_eval_template(name: str) -> "EvalTemplate":
@@ -56,8 +69,7 @@ _register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer: ",
    prefix=" ",
    answer="\nAnswer:",
)


@@ -66,5 +78,4 @@ _register_eval_template(
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
    prefix=" ",
)

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
@@ -404,6 +418,18 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
        },
        "DeepSeek-MoE-Coder-16B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
        },
        "DeepSeek-MoE-Coder-236B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
        },
        "DeepSeek-MoE-Coder-16B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
        },
        "DeepSeek-MoE-Coder-236B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
        },
    },
    template="deepseek",
)
@@ -496,6 +522,18 @@ register_model_group(
        "Gemma-1.1-7B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
        },
        "Gemma-2-9B": {
            DownloadSource.DEFAULT: "google/gemma-2-9b",
        },
        "Gemma-2-27B": {
            DownloadSource.DEFAULT: "google/gemma-2-27b",
        },
        "Gemma-2-9B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2-9b-it",
        },
        "Gemma-2-27B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
        },
    },
    template="gemma",
)
@@ -568,7 +606,7 @@ register_model_group(

register_model_group(
    models={
        "Jambda-v0.1": {
        "Jamba-v0.1": {
            DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
        }
@@ -683,6 +721,21 @@ register_model_group(
)


register_model_group(
    models={
        "MiniCPM-2B-SFT-Chat": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
            DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
        },
        "MiniCPM-2B-DPO-Chat": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
        },
    },
    template="cpm",
)


register_model_group(
    models={
        "Mistral-7B-v0.1": {

@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform

import accelerate
@@ -9,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available


VERSION = "0.8.1.dev0"
VERSION = "0.8.3.dev0"


def print_env() -> None:

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys

@@ -1,13 +1,29 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Tuple

import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
@@ -16,7 +32,6 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version

from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger


@@ -28,8 +43,6 @@ except Exception:


if TYPE_CHECKING:
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import ModelArguments


@@ -58,6 +71,9 @@ class AverageMeter:


def check_dependencies() -> None:
    r"""
    Checks the version of the required packages.
    """
    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
    else:
@@ -68,7 +84,7 @@ def check_dependencies() -> None:
        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
@@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
@@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    return trainable_params, all_param
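
The `Params4bit` branch above corrects for packed 4-bit storage: a quantized tensor's `.numel()` counts storage elements, not weights. A worked sketch of the arithmetic, assuming bitsandbytes packs two 4-bit weights into each byte of the storage dtype:

# Sketch of the adjustment: stored elements * 2 * bytes-per-element = real weight count.
def params4bit_numel(stored_numel: int, storage_itemsize: int) -> int:
    return stored_numel * 2 * storage_itemsize

# e.g. 1_000_000 uint8 storage elements (itemsize 1) hold 2_000_000 4-bit weights
assert params4bit_numel(1_000_000, 1) == 2_000_000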


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        from safetensors import safe_open
        from safetensors.torch import save_file

        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param

    os.remove(path_to_checkpoint)
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info("Value head model saved at: {}".format(output_dir))


def get_current_device() -> torch.device:
def get_current_device() -> "torch.device":
    r"""
    Gets the current available device.
    """
@@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList":
    return logits_processor


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
def has_tokenized_data(path: "os.PathLike") -> bool:
    r"""
    Checks if the path has a tokenized dataset.
    """
    return os.path.isdir(path) and len(os.listdir(path)) > 0


def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
@@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
    return is_torch_npu_available() or is_torch_cuda_available()


def has_tokenized_data(path: os.PathLike) -> bool:
    r"""
    Checks if the path has a tokenized dataset.
    """
    return os.path.isdir(path) and len(os.listdir(path)) > 0
def skip_check_imports() -> None:
    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
        transformers.dynamic_module_utils.check_imports = get_relative_imports


def torch_gc() -> None:

@@ -1,5 +1,23 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING

from packaging import version
@@ -24,10 +42,6 @@ def is_fastapi_available():
    return _is_package_available("fastapi")


def is_flash_attn2_available():
    return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")


def is_galore_available():
    return _is_package_available("galore_torch")

@@ -36,18 +50,10 @@ def is_gradio_available():
    return _is_package_available("gradio")


def is_jieba_available():
    return _is_package_available("jieba")


def is_matplotlib_available():
    return _is_package_available("matplotlib")


def is_nltk_available():
    return _is_package_available("nltk")


def is_pillow_available():
    return _is_package_available("PIL")

@@ -60,10 +66,6 @@ def is_rouge_available():
    return _is_package_available("rouge_chinese")


def is_sdpa_available():
    return _get_package_version("torch") > version.parse("2.1.1")


def is_starlette_available():
    return _is_package_available("sse_starlette")

@@ -74,3 +76,8 @@ def is_uvicorn_available():

def is_vllm_available():
    return _is_package_available("vllm")


@lru_cache
def is_vllm_version_greater_than_0_5():
    return _get_package_version("vllm") >= version.parse("0.5.0")

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments

@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Literal, Optional

@@ -28,10 +45,6 @@ class DataArguments:
        default=1024,
        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
    )
    reserved_label_len: int = field(
        default=1,
        metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether to disable the mask on the prompt or not."},
@@ -90,15 +103,16 @@ class DataArguments:
            "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
        },
    )
    tool_format: Optional[str] = field(
        default=None,
        metadata={"help": "Tool format to use for constructing function calling examples."},
    )
    tokenized_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the tokenized datasets."},
    )

    def __post_init__(self):
        if self.reserved_label_len >= self.cutoff_len:
            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")

        if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
            raise ValueError("Streaming mode should have an integer val size.")


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Literal, Optional

@@ -1,5 +1,19 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import List, Literal, Optional


@dataclass
@@ -94,6 +108,18 @@ class LoraArguments:
        default=False,
        metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
    )
    pissa_init: bool = field(
        default=False,
        metadata={"help": "Whether or not to initialize a PiSSA adapter."},
    )
    pissa_iter: int = field(
        default=16,
        metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
    )
    pissa_convert: bool = field(
        default=False,
        metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
@@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
                return [item.strip() for item in arg.split(",")]
            return arg

        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
        self.lora_target = split_arg(self.lora_target)
        self.additional_target = split_arg(self.additional_target)
        self.galore_target = split_arg(self.galore_target)
        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
        self.lora_target: List[str] = split_arg(self.lora_target)
        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
        self.galore_target: List[str] = split_arg(self.galore_target)
        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")

@@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")

        if self.pissa_init and self.finetuning_type != "lora":
            raise ValueError("`pissa_init` is only valid for LoRA training.")

        if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
            raise ValueError("Cannot use PiSSA for current training stage.")

        if self.train_mm_proj_only and self.finetuning_type != "full":
            raise ValueError("`train_mm_proj_only` is only valid for full training.")

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@@ -1,5 +1,28 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

from typing_extensions import Self


if TYPE_CHECKING:
    import torch


@dataclass
@@ -22,6 +45,10 @@ class ModelArguments:
            )
        },
    )
    adapter_folder: Optional[str] = field(
        default=None,
        metadata={"help": "The folder containing the adapter weights to load."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
@@ -50,6 +77,10 @@ class ModelArguments:
        default=True,
        metadata={"help": "Whether or not to use memory-efficient model loading."},
    )
    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
        default="bitsandbytes",
        metadata={"help": "Quantization method to use for on-the-fly quantization."},
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
@@ -70,7 +101,7 @@ class ModelArguments:
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )
    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
        default="auto",
        metadata={"help": "Enable FlashAttention for faster training and inference."},
    )
@@ -127,13 +158,9 @@ class ModelArguments:
        metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
    )
    vllm_max_lora_rank: int = field(
        default=8,
        default=32,
        metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
    )
    vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations in the vLLM engine."},
    )
    offload_folder: str = field(
        default="offload",
        metadata={"help": "Path to offload model weights."},
@@ -142,6 +169,10 @@ class ModelArguments:
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )
    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations at inference."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -192,9 +223,9 @@ class ModelArguments:
    )

    def __post_init__(self):
        self.compute_dtype = None
        self.device_map = None
        self.model_max_length = None
        self.compute_dtype: Optional["torch.dtype"] = None
        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
        self.model_max_length: Optional[int] = None

        if self.split_special_tokens and self.use_fast_tokenizer:
            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -208,11 +239,18 @@ class ModelArguments:
        if self.new_special_tokens is not None:  # support multiple special tokens
            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]

        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."

        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
            raise ValueError("Quantization dataset is necessary for exporting.")

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
        arg_dict = old_arg.to_dict()
        arg_dict.update(**kwargs)
        new_arg = cls(**arg_dict)
        new_arg.compute_dtype = old_arg.compute_dtype
        new_arg.device_map = old_arg.device_map
        new_arg.model_max_length = old_arg.model_max_length
        return new_arg

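The new `copyfrom` classmethod clones an argument object while explicitly carrying over the runtime-only fields that `__post_init__` would otherwise reset to None. A hedged usage sketch (field values illustrative):

old_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
new_args = ModelArguments.copyfrom(old_args, quantization_bit=4)
assert new_args.compute_dtype == old_args.compute_dtype  # copied explicitly, not via asdict()
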
@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
@@ -8,6 +25,7 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version

@@ -65,13 +83,13 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Adapter is only valid for the LoRA method.")

    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantization is only compatible with the LoRA method.")

        if finetuning_args.pissa_init:
            raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")

        if model_args.resize_vocab:
            raise ValueError("Cannot resize embedding layers of a quantized model.")

@@ -100,7 +118,7 @@ def _check_extra_dependencies(
        require_version("galore_torch", "To fix: pip install galore_torch")

    if finetuning_args.use_badam:
        require_version("badam", "To fix: pip install badam")
        require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")

    if finetuning_args.plot_loss:
        require_version("matplotlib", "To fix: pip install matplotlib")
@@ -162,6 +180,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    ):
        raise ValueError("PPO only accepts wandb or tensorboard logger.")

    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
        raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")

    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")

@@ -171,32 +195,31 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if training_args.do_train and model_args.quantization_device_map == "auto":
        raise ValueError("Cannot use device map for quantized models in training.")

    if finetuning_args.use_dora and model_args.use_unsloth:
        raise ValueError("Unsloth does not support DoRA.")
    if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
        raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.")

    if finetuning_args.pure_bf16:
        if not is_torch_bf16_gpu_available():
            raise ValueError("This device does not support `pure_bf16`.")

        if training_args.fp16 or training_args.bf16:
            raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
        if is_deepspeed_zero3_enabled():
            raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

    if (
        finetuning_args.use_galore
        and finetuning_args.galore_layerwise
        and training_args.parallel_mode.value == "distributed"
        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
    ):
        raise ValueError("Distributed training does not support layer-wise GaLore.")

    if (
        finetuning_args.use_badam
        and finetuning_args.badam_mode == "layer"
        and training_args.parallel_mode.value == "distributed"
    ):
        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
        if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
        elif not is_deepspeed_zero3_enabled():
            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

    if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
        raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
    if finetuning_args.use_galore and training_args.deepspeed is not None:
        raise ValueError("GaLore is incompatible with DeepSpeed yet.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -204,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if model_args.visual_inputs and data_args.packing:
        raise ValueError("Cannot use packing in MLLM fine-tuning.")

    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    _verify_model_args(model_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)

@@ -233,7 +259,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

    # Post-process training arguments
    if (
        training_args.parallel_mode.value == "distributed"
        training_args.parallel_mode == ParallelMode.DISTRIBUTED
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
@@ -293,7 +319,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            training_args.parallel_mode.value == "distributed",
            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
            str(model_args.compute_dtype),
        )
    )
@@ -332,6 +358,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:

    if model_args.export_dir is not None and model_args.export_device == "cpu":
        model_args.device_map = {"": torch.device("cpu")}
        model_args.model_max_length = data_args.cutoff_len
    else:
        model_args.device_map = "auto"


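A recurring change in these hunks swaps string comparisons for the `ParallelMode` enum. Comparing against the enum member survives any rename of the underlying string value; a minimal sketch of the pattern:

from transformers.training_args import ParallelMode

def is_distributed(training_args) -> bool:
    # enum comparison instead of training_args.parallel_mode.value == "distributed"
    return training_args.parallel_mode == ParallelMode.DISTRIBUTED
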
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from llamafactory.train.tuner import run_exp



@@ -1,9 +1,25 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params


__all__ = [
    "QuantizationMethod",
    "load_config",
    "load_model",
    "load_tokenizer",

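With these re-exports, callers import from the package root rather than reaching into submodules; for example (signature assumed from the surrounding code, which passes the parsed ModelArguments):

from llamafactory.model import load_config, load_model, load_tokenizer

config = load_config(model_args)  # a sketch of the intended entry point
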
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import TYPE_CHECKING

@@ -25,8 +39,12 @@ def _setup_full_tuning(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> None:
    if not is_trainable:
        return

    logger.info("Fine-tuning method: Full")
    forbidden_modules = set()
    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> None:
    if not is_trainable:
        return

    logger.info("Fine-tuning method: Freeze")
    if model_args.visual_inputs:
        config = model.config.text_config
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
    logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
    if is_trainable:
        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

    adapter_to_resume = None

    if model_args.adapter_name_or_path is not None:
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
        else:
            adapter_to_merge = model_args.adapter_name_or_path

        init_kwargs = {
            "subfolder": model_args.adapter_folder,
            "offload_folder": model_args.offload_folder,
            "cache_dir": model_args.cache_dir,
            "revision": model_args.model_revision,
            "token": model_args.hf_hub_token,
        }

        for adapter in adapter_to_merge:
            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
            model = model.merge_and_unload()

        if len(adapter_to_merge) > 0:
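The new `init_kwargs` dict forwards Hub download options (subfolder, revision, auth token) through `PeftModel.from_pretrained` instead of only the offload folder. A standalone sketch with illustrative values:

from peft import PeftModel

init_kwargs = {
    "subfolder": None,            # e.g. a checkpoint folder inside the adapter repo
    "offload_folder": "offload",
    "cache_dir": None,
    "revision": "main",
    "token": None,                # Hub auth token, if the adapter repo is private
}
model = PeftModel.from_pretrained(model, "user/adapter-repo", **init_kwargs)  # repo name illustrative
model = model.merge_and_unload()  # bake the LoRA deltas into the base weights
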
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
        if model_args.use_unsloth:
            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
        else:
            model = PeftModel.from_pretrained(
                model,
                adapter_to_resume,
                is_trainable=is_trainable,
                offload_folder=model_args.offload_folder,
            )
            model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

    if is_trainable and adapter_to_resume is None:  # create new lora weights while training
        if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -209,16 +238,24 @@ def _setup_lora_tuning(
            "lora_alpha": finetuning_args.lora_alpha,
            "lora_dropout": finetuning_args.lora_dropout,
            "use_rslora": finetuning_args.use_rslora,
            "use_dora": finetuning_args.use_dora,
            "modules_to_save": finetuning_args.additional_target,
        }

        if model_args.use_unsloth:
            model = get_unsloth_peft_model(model, model_args, peft_kwargs)
        else:
            if finetuning_args.pissa_init:
                if finetuning_args.pissa_iter == -1:
                    logger.info("Using PiSSA initialization.")
                    peft_kwargs["init_lora_weights"] = "pissa"
                else:
                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)

            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                use_dora=finetuning_args.use_dora,
                **peft_kwargs,
            )
            model = get_peft_model(model, lora_config)
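As the hunk shows, PiSSA is selected purely through PEFT's `init_lora_weights` string. A minimal sketch, assuming a PEFT release that supports PiSSA (v0.11 or later):

from peft import LoraConfig

lora_config = LoraConfig(r=8, init_lora_weights="pissa")           # exact SVD initialization
fast_config = LoraConfig(r=8, init_lora_weights="pissa_niter_16")  # fast SVD with 16 iterations
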
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

    if model_args.adapter_name_or_path is not None:
        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

    return model


@@ -247,29 +281,36 @@ def init_adapter(

    Note that the trainable parameters must be cast to float32.
    """
    if (not is_trainable) and model_args.adapter_name_or_path is None:
        logger.info("Adapter is not found at evaluation, load the base model.")
        return model
    if is_trainable and getattr(model, "quantization_method", None) is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantized models can only be used for the LoRA tuning.")

    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
        raise ValueError("You can only use lora for quantized models.")
        if finetuning_args.pissa_init:
            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")

    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
        cast_trainable_params_to_fp32 = False
    # cast trainable parameters to float32 if:
    # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
    cast_trainable_params_to_fp32 = False
    if not is_trainable:
        pass
    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
    else:
        logger.info("Upcasting trainable params to float32.")
        cast_trainable_params_to_fp32 = True

    if is_trainable and finetuning_args.finetuning_type == "full":
        _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)

    if is_trainable and finetuning_args.finetuning_type == "freeze":
        _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)

    if finetuning_args.finetuning_type == "lora":
    if finetuning_args.finetuning_type == "full":
        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "freeze":
        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "lora":
        model = _setup_lora_tuning(
            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
        )
    else:
        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))

    return model

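The rewritten branch condenses the upcasting rules into a small decision table. The same logic as a pure function, for reference:

def should_cast_to_fp32(is_trainable, pure_bf16, use_badam, quantization_bit, zero3_or_fsdp) -> bool:
    if not is_trainable:
        return False  # inference keeps the loaded dtype
    if pure_bf16 or use_badam:
        return False  # these setups want half-precision trainable params
    if quantization_bit is None and zero3_or_fsdp:
        return False  # ZeRO-3/FSDP already hold trainable params in float32
    return True       # plain training or QLoRA: upcast trainable params to float32
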
@@ -1,10 +1,25 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -33,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

    Note: including inplace operation of model_args.
    """
    skip_check_imports()
    model_args.model_name_or_path = try_download_model_from_ms(model_args)
    return {
        "trust_remote_code": True,
@@ -162,17 +178,21 @@ def load_model(

    if not is_trainable:
        model.requires_grad_(False)
        for param in model.parameters():
            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
                param.data = param.data.to(model_args.compute_dtype)

        model.eval()
    else:
        model.train()

    trainable_params, all_param = count_parameters(model)
    if is_trainable:
        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
            trainable_params, all_param, 100 * trainable_params / all_param
        )
    else:
        param_stats = "all params: {:d}".format(all_param)
        param_stats = "all params: {:,}".format(all_param)

    logger.info(param_stats)


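`count_parameters` is not part of this diff; a plausible sketch of its shape, matching the new `{:,}` thousands-separator formatting (the real helper likely also accounts for DeepSpeed ZeRO-3 partitioned and 4-bit packed weights):

def count_parameters(model):
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param

print("trainable params: {:,} || all params: {:,}".format(*count_parameters(model)))
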
@@ -1,7 +1,22 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available


if TYPE_CHECKING:
@@ -13,21 +28,33 @@ if TYPE_CHECKING:
logger = get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
def configure_attn_implementation(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
        if model_args.flash_attn == "auto":
            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
            model_args.flash_attn = "disabled"
        elif model_args.flash_attn != "disabled":
            logger.warning(
                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
                "Will proceed at your own risk.".format(model_args.flash_attn)
            )

    if model_args.flash_attn == "auto":
        return

    elif model_args.flash_attn == "off":
    elif model_args.flash_attn == "disabled":
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == "sdpa":
        if not is_sdpa_available():
        if not is_torch_sdpa_available():
            logger.warning("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == "fa2":
        if not is_flash_attn2_available():
        if not is_flash_attn_2_available():
            logger.warning("FlashAttention-2 is not installed.")
            return


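The renamed `disabled` option and the availability guards boil down to a mapping onto Hugging Face's `attn_implementation` argument. A simplified sketch of that mapping (the real function also handles the Gemma-2 training case above):

_ATTN_IMPL = {"disabled": "eager", "sdpa": "sdpa", "fa2": "flash_attention_2"}

def resolve_attn_implementation(flash_attn: str):
    if flash_attn == "auto":
        return None  # let transformers pick the best available kernel
    return _ATTN_IMPL[flash_attn]
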
@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from functools import partial
from types import MethodType
@@ -60,15 +78,12 @@ def _fp32_forward_post_hook(
    return output.to(torch.float32)


def prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""
    Includes:
        (1) cast the layernorm in fp32
        (2) make output embedding layer require grads
        (3) add the upcasting of the lm_head in fp32
    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
@@ -87,8 +102,8 @@ def prepare_model_for_training(
        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
        logger.info("Gradient checkpointing enabled.")

    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
    if model_args.upcast_lmhead_output:
        output_layer = model.get_output_embeddings()
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            logger.info("Upcasting lm_head outputs in float32.")
            output_layer.register_forward_hook(_fp32_forward_post_hook)

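Only the tail of `_fp32_forward_post_hook` appears in the hunk above; a hedged reconstruction of the full pattern, using the standard PyTorch forward-hook signature:

import torch

def _fp32_forward_post_hook(module, args, output):
    return output.to(torch.float32)  # upcast lm_head logits for a numerically stabler loss

output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
    output_layer.register_forward_hook(_fp32_forward_post_hook)
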
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from contextlib import nullcontext
from typing import TYPE_CHECKING

@@ -1,3 +1,22 @@
# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
#
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# This code is also inspired by the original LongLoRA implementation.
# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Optional, Tuple

@@ -96,7 +115,8 @@ def llama_attention_forward(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
            ),
            dim=2,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -181,11 +201,9 @@ def llama_flash_attention_2_forward(
        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
    else:
        groupsz = q_len

    attn_output: torch.Tensor = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
    )

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
@@ -194,7 +212,8 @@ def llama_flash_attention_2_forward(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
            ),
            dim=2,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +312,8 @@ def llama_sdpa_attention_forward(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
            ),
            dim=2,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +323,7 @@ def llama_sdpa_attention_forward(


def _apply_llama_patch() -> None:
    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
    LlamaAttention.forward = llama_attention_forward
    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
    LlamaSdpaAttention.forward = llama_sdpa_attention_forward

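The repeated cat/roll pattern is LongLoRA's shift-back step: the second half of the attention heads is rotated along the sequence axis by half a group. On a toy tensor (batch=1, seq=4, heads=2) the operation looks like this:

import torch

x = torch.arange(8).reshape(1, 4, 2)  # (bsz, seq, heads)
groupsz = 4
shifted_back = torch.cat(
    (x[:, :, :1], x[:, :, 1:].roll(groupsz // 2, dims=1)),  # roll only the second half of the heads
    dim=2,
)
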
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List

from ...extras.logging import get_logger

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS

@@ -1,5 +1,20 @@
from typing import TYPE_CHECKING
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Sequence

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version

@@ -10,6 +25,13 @@ if TYPE_CHECKING:
    from ...hparams import ModelArguments


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    set_z3_leaf_modules(model, leaf_modules)


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    r"""
    Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    if not is_deepspeed_zero3_enabled():
        return

    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    if getattr(model.config, "model_type", None) == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN

        set_z3_leaf_modules(model, [DbrxFFN])
        _set_z3_leaf_modules(model, [DbrxFFN])

    if getattr(model.config, "model_type", None) == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

        set_z3_leaf_modules(model, [JambaSparseMoeBlock])
        _set_z3_leaf_modules(model, [JambaSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "jetmoe":
        from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

        set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

    if getattr(model.config, "model_type", None) == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
        _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "qwen2moe":
        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:

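The `_set_z3_leaf_modules` wrapper centralizes the version pin and the deepspeed import. Marking an MoE block as a ZeRO-3 leaf stops DeepSpeed from partitioning parameters inside it, which gated-expert layers require; usage follows the pattern in the diff:

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# assuming the helper defined above and an active ZeRO-3 session
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
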
@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and Optimum library.
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from enum import Enum, unique
@@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
@@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum):
    HQQ = "hqq"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
    r"""
    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
    """
    if os.path.isfile(model_args.export_quantization_dataset):
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
        data_path = model_args.export_quantization_dataset
        data_files = None

    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
    maxlen = model_args.export_quantization_maxlen
    dataset = load_dataset(
        path=data_path,
        data_files=data_files,
        split="train",
        cache_dir=model_args.cache_dir,
        token=model_args.hf_hub_token,
    )

    samples = []
    maxlen = model_args.export_quantization_maxlen
    for _ in range(model_args.export_quantization_nsamples):
        n_try = 0
        while True:
            if n_try > 100:
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")

            sample_idx = random.randint(0, len(dataset) - 1)
            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            if sample["input_ids"].size(1) >= maxlen:
            sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            n_try += 1
            if sample["input_ids"].size(1) > maxlen:
                break  # TODO: fix large maxlen

        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
        attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
        samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})

    return samples

@@ -76,14 +105,14 @@ def configure_quantization(
    init_kwargs: Dict[str, Any],
) -> None:
    r"""
    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
    """
    if getattr(config, "quantization_config", None):  # ptq
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
        if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")

        if model_args.quantization_device_map != "auto":
            init_kwargs["device_map"] = {"": get_current_device()}
        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")
@@ -105,46 +134,72 @@ def configure_quantization(
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

    elif model_args.export_quantization_bit is not None:  # auto-gptq
        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
        if model_args.export_quantization_bit not in [8, 4, 3, 2]:
            raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
        from accelerate.utils import get_max_memory

        if getattr(config, "model_type", None) == "chatglm":
            raise ValueError("ChatGLM model is not supported.")
            raise ValueError("ChatGLM model is not supported yet.")

        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
            tokenizer=tokenizer,
            dataset=_get_quantization_dataset(tokenizer, model_args),
        )
        init_kwargs["device_map"] = "auto"
        init_kwargs["max_memory"] = get_max_memory()
        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))

    elif model_args.quantization_bit is not None:  # bnb
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif model_args.quantization_bit is not None:  # on-the-fly
        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
            if model_args.quantization_bit == 8:
                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
                init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            elif model_args.quantization_bit == 4:
                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
                init_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=model_args.compute_dtype,
                    bnb_4bit_use_double_quant=model_args.double_quantization,
                    bnb_4bit_quant_type=model_args.quantization_type,
                    bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
                )
            else:
                raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type,
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
            )
            # Do not assign device map if:
            # 1. deepspeed zero3 or fsdp (train)
            # 2. auto quantization device map (inference)
            if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
                if model_args.quantization_bit != 4:
                    raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
            if model_args.quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")
                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
            else:
                init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
        else:
            init_kwargs["device_map"] = {"": get_current_device()}
            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
            if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")

            logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

            require_version("hqq", "To fix: pip install hqq")
            init_kwargs["quantization_config"] = HqqConfig(
                nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
            )  # use ATEN kernel (axis=0) for performance
            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
            if model_args.quantization_bit != 8:
                raise ValueError("EETQ only accepts 8-bit quantization.")

            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

            require_version("eetq", "To fix: pip install eetq")
            init_kwargs["quantization_config"] = EetqConfig()
            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))

|
||||
# Copyright 2024 LMSYS and the LlamaFactory team.
|
||||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
||||
#
|
||||
# This code is inspired by the LMSYS's FastChat library.
|
||||
# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
return
|
||||
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
if model_args.model_max_length is not None:
|
||||
if is_trainable and model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's Transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
@@ -46,13 +60,16 @@ def patch_config(
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
if model_args.infer_dtype != "auto" and not is_trainable:
|
||||
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
||||
else:
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
if is_torch_npu_available():
|
||||
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
||||
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
|
||||
|
||||
configure_attn_implementation(config, model_args)
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
@@ -74,14 +91,17 @@ def patch_config(
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
|
||||
# cast data type of the model if:
|
||||
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
|
||||
# 2. quantization_bit is not None (qlora)
|
||||
if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
|
||||
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
|
||||
if "device_map" not in init_kwargs and model_args.device_map:
|
||||
init_kwargs["device_map"] = model_args.device_map
|
||||
|
||||
if init_kwargs["device_map"] == "auto":
|
||||
if init_kwargs.get("device_map", None) == "auto":
|
||||
init_kwargs["offload_folder"] = model_args.offload_folder
|
||||
|
||||
if finetune_args.stage == "sft" and data_args.efficient_packing:
|
||||
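The switch from `init_kwargs["device_map"]` to `init_kwargs.get("device_map", None)` matters when no device map was ever set: plain indexing raises, while `.get` simply yields None. For illustration:

init_kwargs = {}                               # "device_map" was never set
init_kwargs.get("device_map", None) == "auto"  # -> False, no exception
# init_kwargs["device_map"] == "auto"          # would raise KeyError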
@@ -137,6 +157,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)
@@ -145,4 +169,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
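The setattr lines above rely on `types.MethodType` to bind plain functions as instance methods, since the value-head wrapper does not implement them itself. A self-contained toy version of the pattern (the class and function here are made up):

from types import MethodType

class Holder:
    inner = "wrapped value"

def describe(self: "Holder") -> str:
    return self.inner  # `self` is bound to the instance below

h = Holder()
setattr(h, "describe", MethodType(describe, h))
print(h.describe())  # -> "wrapped value"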
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
@@ -8,22 +22,78 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from transformers import TrainerCallback
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)

from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file

if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead


logger = get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param

    os.remove(path_to_checkpoint)
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info("Value head model saved at: {}".format(output_dir))
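The key partition performed by `fix_valuehead_checkpoint` can be checked in isolation; a pure-Python toy of the same split (dummy values instead of tensors):

state_dict = {
    "pretrained_model.model.layers.0.weight": 0,  # zero3-style prefixed key
    "v_head.summary.weight": 1,
}
decoder = {k.replace("pretrained_model.", ""): v for k, v in state_dict.items() if not k.startswith("v_head.")}
v_head = {k: v for k, v in state_dict.items() if k.startswith("v_head.")}
assert "model.layers.0.weight" in decoder and "v_head.summary.weight" in v_head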
class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
@@ -37,8 +107,70 @@ class FixValueHeadModelCallback(TrainerCallback):
        )


class SaveProcessorCallback(TrainerCallback):
    def __init__(self, processor: "ProcessorMixin") -> None:
        r"""
        Initializes a callback for saving the processor.
        """
        self.processor = processor

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
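`SaveProcessorCallback` only needs an object exposing an `image_processor` with a `save_pretrained` method, so it can be exercised without a real trainer; a runnable toy (the dummy classes are made up, not the transformers API):

from types import SimpleNamespace

class DummyImageProcessor:
    def save_pretrained(self, output_dir: str) -> None:
        print("image processor saved to " + output_dir)

processor = SimpleNamespace(image_processor=DummyImageProcessor())
args = SimpleNamespace(should_save=True, output_dir="outputs")
SaveProcessorCallback(processor).on_train_end(args, state=None, control=None)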
class PissaConvertCallback(TrainerCallback):
    r"""
    Initializes a callback for converting the PiSSA adapter to a normal one.
    """

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                model.delete_adapter("pissa_init")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
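The four-step sequence in `on_train_end` exists because peft performs the PiSSA-to-LoRA conversion at save time: given the initial adapter saved at train begin, it rewrites the trained adapter so the result loads as an ordinary LoRA. A condensed sketch of the conversion step alone, assuming `model` is a PiSSA-initialized `PeftModel` (the `convert_pissa_to_lora` keyword is the peft 0.11-era name that appears in this diff):

model.save_pretrained(
    "outputs/pissa_converted",
    safe_serialization=True,
    convert_pissa_to_lora="outputs/pissa_init",  # path written by on_train_begin above
)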
class LogCallback(TrainerCallback):
    def __init__(self, output_dir: str) -> None:
    def __init__(self) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.
        """
@@ -56,7 +188,7 @@ class LogCallback(TrainerCallback):
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(output_dir)
            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_dpo


@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -10,7 +28,8 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model

from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer):
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.processor = processor
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
@@ -61,6 +79,8 @@ class CustomDPOTrainer(DPOTrainer):
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
@@ -71,10 +91,17 @@ class CustomDPOTrainer(DPOTrainer):
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
        if finetuning_args.pissa_convert:
            self.callback_handler.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
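Much of this commit moves per-trainer save logic into callbacks registered via `Trainer.add_callback`, which accepts either a callback class or an instance. The idiom in miniature (`PrintOnSave` is a made-up example):

from transformers import TrainerCallback

class PrintOnSave(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        print("checkpoint written at step {}".format(state.global_step))

# trainer.add_callback(PrintOnSave)    # a class: the trainer instantiates it
# trainer.add_callback(PrintOnSave())  # or an already-built instance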
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -87,12 +114,6 @@ class CustomDPOTrainer(DPOTrainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""
        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -176,7 +197,7 @@ class CustomDPOTrainer(DPOTrainer):

        if self.ref_model is None:
            ref_model = model
            ref_context = get_ref_context(self.accelerator, model)
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()
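When no separate reference model is available, the policy itself stands in: entering `disable_adapter()` turns the LoRA deltas off, so a forward pass inside the context reproduces the frozen base model, while `nullcontext()` keeps the code path uniform when a real reference model exists. A sketch of how such a context is typically consumed (the shape of the surrounding code, not a verbatim excerpt):

with torch.no_grad(), ref_context:  # adapter disabled here if ref_model is the policy
    reference_logits = ref_model(**batch).logits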
@@ -1,4 +1,19 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_kto


@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -9,7 +27,8 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model

from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.processor = processor
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
@@ -60,6 +78,8 @@ class CustomKTOTrainer(KTOTrainer):
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
@@ -70,10 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -92,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
        """
        return Trainer._get_train_sampler(self)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -143,7 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
        """
        if self.ref_model is None:
            ref_model = model
            ref_context = get_ref_context(self.accelerator, model)
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_ppo


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

@@ -1,6 +1,24 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

@@ -9,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@@ -16,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation

from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm

@@ -81,10 +100,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        )

        # Add deepspeed config
        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
        ]
        if training_args.deepspeed_plugin is not None:
            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
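The hunk above stops installing the DDP kwargs handler unconditionally and instead configures the accelerator only when a DeepSpeed plugin is present, forwarding the plugin alongside. Reduced to its intent (field names as in the diff, `training_args` assumed to exist):

accelerator_kwargs = {}
if training_args.deepspeed_plugin is not None:
    accelerator_kwargs["kwargs_handlers"] = [
        DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
    ]
    accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin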
        # Create optimizer and scheduler
@@ -113,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model
        self.current_device = get_current_device()  # patch for deepspeed training
        self.processor = processor

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
@@ -125,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        self.control = TrainerControl()
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
        self.callback_handler = CallbackHandler(
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )

        if self.args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -134,8 +153,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"

        device_type = unwrapped_model.pretrained_model.device.type
        self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
        self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if finetuning_args.reward_model_type == "full":
            if self.is_deepspeed_enabled:
@@ -147,10 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor
        self.add_callback(FixValueHeadModelCallback)

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""
@@ -184,23 +209,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info("  Num examples = {}".format(num_examples))
            logger.info("  Num Epochs = {}".format(num_train_epochs))
            logger.info("  Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
            logger.info("  Num examples = {:,}".format(num_examples))
            logger.info("  Num Epochs = {:,}".format(num_train_epochs))
            logger.info("  Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
            logger.info(
                "  Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
                "  Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
                    total_train_batch_size
                )
            )
            logger.info("  Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
            logger.info("  Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
            logger.info("  Total training steps = {}".format(max_steps))
            logger.info("  Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
            logger.info("  Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
            logger.info("  Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
            logger.info("  Total training steps = {:,}".format(max_steps))
            logger.info("  Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
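The only change in the logging block above is the `{:,}` format spec, which inserts thousands separators into the reported counts:

"{}".format(6738415616)    # '6738415616'
"{:,}".format(6738415616)  # '6,738,415,616'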
        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.log_callback.on_train_begin(self.args, self.state, self.control)
        self.callback_handler.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
@@ -238,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                logger.warning("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.log_callback.on_step_end(self.args, self.state, self.control)
            self.callback_handler.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                logs = dict(
@@ -250,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, self.control)
                self.callback_handler.on_log(self.args, self.state, self.control, logs)
                loss_meter.reset()
                reward_meter.reset()

@@ -258,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                self.save_model(
                    os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
                )
                self.save_callback.on_save(
                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
                )
                self.callback_handler.on_save(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.log_callback.on_train_end(self.args, self.state, self.control)
        self.save_callback.on_train_end(
            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
        )
        self.callback_handler.on_train_end(self.args, self.state, self.control)

    def create_optimizer(
        self,
@@ -486,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

        elif self.args.should_save:
            self._save(output_dir)

        if self.processor is not None and self.args.should_save:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

@@ -1,14 +1,28 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

from transformers import DataCollatorWithPadding

from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer

@@ -60,6 +74,7 @@ def run_ppo(
    ppo_trainer.save_model()
    if training_args.should_save:
        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)

    ppo_trainer.save_state()  # must be called after save_model to have a folder
    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
        plot_loss(training_args.output_dir, keys=["loss", "reward"])

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_pt


@@ -1,9 +1,24 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional

from transformers import Trainer

from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler


@@ -27,11 +42,18 @@ class CustomTrainer(Trainer):
    ) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        self.processor = processor
        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -43,9 +65,3 @@ class CustomTrainer(Trainer):
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

@@ -1,4 +1,19 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_rm


@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Sequence, Tuple, Union

import numpy as np

@@ -1,3 +1,42 @@
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2022 CarperAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
import os
from types import MethodType
@@ -7,6 +46,7 @@ import torch
from transformers import Trainer

from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler


@@ -30,12 +70,20 @@ class PairwiseTrainer(Trainer):
    ) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        self.processor = processor
        self.can_return_loss = True  # override property to return eval_loss
        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor
        self.add_callback(FixValueHeadModelCallback)

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.pissa_convert:
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -48,12 +96,6 @@ class PairwiseTrainer(Trainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            getattr(self.processor, "image_processor").save_pretrained(output_dir)

    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -63,7 +105,7 @@ class PairwiseTrainer(Trainer):
        Subclass and override to inject custom behavior.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
        See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
        """
        # Compute rewards
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
@@ -79,7 +121,6 @@ class PairwiseTrainer(Trainer):
        chosen_scores, rejected_scores = [], []

        # Compute pairwise loss. Only backprop on the different tokens before padding
        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
        loss = 0
        for i in range(batch_size):
            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
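For context, the loop above implements the usual pairwise (Bradley-Terry) reward-model objective, restricted to tokens where the chosen and rejected sequences diverge. A condensed tensor-level sketch, simplified to one score per example (made-up values):

import torch
import torch.nn.functional as F

chosen_rewards = torch.tensor([1.2, 0.3])    # r(x, y_chosen) per example
rejected_rewards = torch.tensor([0.4, 0.9])  # r(x, y_rejected) per example
loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()  # pushes chosen above rejected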
@@ -125,4 +166,5 @@ class PairwiseTrainer(Trainer):
            res: List[str] = []
            for c_score, r_score in zip(chosen_scores, rejected_scores):
                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

            writer.write("\n".join(res))

@@ -1,12 +1,48 @@
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2022 CarperAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import TYPE_CHECKING, List, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy
from .trainer import PairwiseTrainer
@@ -40,7 +76,7 @@ def run_rm(
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks + [FixValueHeadModelCallback()],
        callbacks=callbacks,
        compute_metrics=compute_accuracy,
        **tokenizer_module,
        **split_dataset(dataset, data_args, training_args),
@@ -52,6 +88,7 @@ def run_rm(
    trainer.save_model()
    if training_args.should_save:
        fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)

    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_sft


@@ -1,14 +1,35 @@
# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict

import numpy as np
import torch
from transformers import EvalPrediction
from transformers.utils import is_jieba_available, is_nltk_available

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
from ...extras.packages import is_rouge_available


if TYPE_CHECKING:
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers import PreTrainedTokenizer


if is_jieba_available():
@@ -23,6 +44,22 @@ if is_rouge_available():
    from rouge_chinese import Rouge


def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
    preds, labels = eval_preds.predictions, eval_preds.label_ids
    accuracies = []
    for i in range(len(preds)):
        pred, label = preds[i, :-1], labels[i, 1:]
        label_mask = label != IGNORE_INDEX
        accuracies.append(np.mean(pred[label_mask] == label[label_mask]))

    return {"accuracy": float(np.mean(accuracies))}


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    logits = logits[0] if isinstance(logits, (list, tuple)) else logits
    return torch.argmax(logits, dim=-1)
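The one-position shift in `compute_accuracy` aligns next-token predictions with their targets: the argmax taken in `eval_logit_processor` at position t is the model's guess for the token at t + 1. A tiny worked case:

import numpy as np

IGNORE_INDEX = -100
preds = np.array([[11, 12, 13, 0]])      # argmax at position t guesses token t + 1
labels = np.array([[-100, 11, 12, 13]])  # prompt position masked with IGNORE_INDEX
pred, label = preds[0, :-1], labels[0, 1:]
mask = label != IGNORE_INDEX
print(float(np.mean(pred[mask] == label[mask])))  # -> 1.0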
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
r"""
|
||||
@@ -31,11 +68,11 @@ class ComputeMetrics:
|
||||
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
|
||||
r"""
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
preds, labels = eval_preds.predictions, eval_preds.label_ids
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
@@ -9,10 +26,12 @@ from transformers import Seq2SeqTrainer
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
@@ -32,11 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -49,12 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-        super()._save(output_dir, state_dict)
-        if self.processor is not None:
-            output_dir = output_dir if output_dir is not None else self.args.output_dir
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
    def prediction_step(
        self,
        model: "torch.nn.Module",
@@ -94,7 +114,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
        return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
        r"""
        Saves model predictions to `output_dir`.
@@ -115,18 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

        for i in range(len(preds)):
            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
-            if len(pad_len):
-                preds[i] = np.concatenate(
-                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
-                )  # move pad token to last
+            if len(pad_len):  # move pad token to last
+                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

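
The tightened pad-handling block rotates each sequence so that left padding ends up on the right before decoding. A toy run of the same NumPy logic:

import numpy as np

pad_token_id = 0
seq = np.array([0, 0, 5, 6, 7])                # left-padded prediction
pad_len = np.nonzero(seq != pad_token_id)[0]   # indices of non-pad tokens
if len(pad_len):                               # move pad tokens to the end
    seq = np.concatenate((seq[pad_len[0]:], seq[:pad_len[0]]), axis=-1)
print(seq)  # [5 6 7 0 0]
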
-        decoded_labels = self.tokenizer.batch_decode(
-            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
-            for label, pred in zip(decoded_labels, decoded_preds):
-                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
+                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))

            writer.write("\n".join(res))
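
After this change each JSONL record also carries the decoded prompt. A sketch of consuming the file (the name generated_predictions.jsonl is an assumption about output_prediction_file, which is not shown in this diff):

import json

with open("generated_predictions.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)  # keys: "prompt", "label", "predict"
        print(record["prompt"][:60], "->", record["predict"][:60])
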
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

@@ -10,7 +25,7 @@ from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer

if TYPE_CHECKING:
@@ -56,7 +71,8 @@ def run_sft(
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
-        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+        preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
        **tokenizer_module,
        **split_dataset(dataset, data_args, training_args),
    )
@@ -75,7 +91,7 @@ def run_sft(
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])

    # Evaluation
    if training_args.do_eval:
@@ -92,7 +108,7 @@ def run_sft(
        predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(predict_results)
+        trainer.save_predictions(dataset, predict_results)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
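
Passing eval_logit_processor as preprocess_logits_for_metrics matters for memory: without it, the Trainer gathers the full [batch, seq_len, vocab] logits across devices before compute_metrics runs; reducing to argmax ids first keeps only a [batch, seq_len] integer tensor. A toy shape check:

import torch

logits = torch.randn(2, 16, 32000)   # [batch, seq_len, vocab] -- large
ids = torch.argmax(logits, dim=-1)   # [batch, seq_len] -- what actually gets gathered
assert ids.shape == (2, 16)
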
@@ -1,8 +1,27 @@
-from contextlib import contextmanager
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
+# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
+# and the original BAdam's implementation: https://github.com/Ledzy/BAdam
+# and the HuggingFace's TRL library: https://github.com/huggingface/trl
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
+from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
@@ -19,7 +38,6 @@ if is_galore_available():


if TYPE_CHECKING:
-    from accelerate import Accelerator
    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
    from trl import AutoModelForCausalLMWithValueHead
@@ -83,15 +101,12 @@ def create_ref_model(
    The valuehead parameter is randomly initialized since it is useless for PPO training.
    """
    if finetuning_args.ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.ref_model,
-                adapter_name_or_path=finetuning_args.ref_model_adapters,
-                quantization_bit=finetuning_args.ref_model_quantization_bit,
-            )
-        )
-        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.ref_model,
+            adapter_name_or_path=finetuning_args.ref_model_adapters,
+            quantization_bit=finetuning_args.ref_model_quantization_bit,
+        )
        ref_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
        ref_model = load_model(
@@ -102,9 +117,11 @@
        if finetuning_args.finetuning_type == "lora":
            ref_model = None
        else:
-            tokenizer = load_tokenizer(model_args)["tokenizer"]
+            ref_model_args = ModelArguments.copyfrom(model_args)
+            ref_finetuning_args = FinetuningArguments()
+            tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
            ref_model = load_model(
-                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+                tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info("Created reference model from the model itself.")
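
ModelArguments.copyfrom replaces the earlier to_dict()/re-construct dance. Its definition is not part of this diff; a plausible minimal sketch, assuming it clones the source dataclass and applies keyword overrides (the field list here is illustrative, not the full class):

from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = None
    adapter_name_or_path: Optional[str] = None
    quantization_bit: Optional[int] = None

    @classmethod
    def copyfrom(cls, source: "ModelArguments", **kwargs) -> "ModelArguments":
        init_args = {f.name: getattr(source, f.name) for f in fields(source)}
        init_args.update(kwargs)  # explicit overrides win over copied values
        return cls(**init_args)
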
@@ -139,15 +156,12 @@ def create_reward_model(
        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
        return None
    else:
-        reward_model_args_dict = model_args.to_dict()
-        reward_model_args_dict.update(
-            dict(
-                model_name_or_path=finetuning_args.reward_model,
-                adapter_name_or_path=finetuning_args.reward_model_adapters,
-                quantization_bit=finetuning_args.reward_model_quantization_bit,
-            )
-        )
-        reward_model_args = ModelArguments(**reward_model_args_dict)
+        reward_model_args = ModelArguments.copyfrom(
+            model_args,
+            model_name_or_path=finetuning_args.reward_model,
+            adapter_name_or_path=finetuning_args.reward_model_adapters,
+            quantization_bit=finetuning_args.reward_model_quantization_bit,
+        )
        reward_finetuning_args = FinetuningArguments()
        tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
        reward_model = load_model(
@@ -158,17 +172,6 @@
    return reward_model


-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-    r"""
-    Gets adapter context for the reference model.
-    """
-    with accelerator.unwrap_model(model).disable_adapter():
-        model.eval()
-        yield
-        model.train()
-
-
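
The removed helper was a plain contextmanager that switched PEFT adapters off for the duration of a block, so a former call site would have looked roughly like this (hypothetical usage; after this commit the adapter-toggling presumably lives with the trainers that need a reference forward pass):

# inside a trainer that shares weights between the policy and reference model
with get_ref_context(self.accelerator, model):
    ref_logits = model(**batch).logits  # base model: adapters disabled, eval mode
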
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
    r"""
    Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
@@ -184,7 +187,7 @@ def _create_galore_optimizer(
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
-        galore_targets = find_all_linear_modules(model)
+        galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
    else:
        galore_targets = finetuning_args.galore_target

@@ -334,6 +337,7 @@ def _create_badam_optimizer(
        start_block=finetuning_args.badam_start_block,
        switch_mode=finetuning_args.badam_switch_mode,
        verbose=finetuning_args.badam_verbose,
+        ds_zero3_enabled=is_deepspeed_zero3_enabled(),
    )
    logger.info(
        f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -355,7 +359,7 @@ def _create_badam_optimizer(
        **optim_kwargs,
    )
    logger.info(
-        f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+        f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
        f"mask mode is {finetuning_args.badam_mask_mode}"
    )

@@ -1,13 +1,30 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
+import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel

from ..data import get_template_and_fix_tokenizer
-from ..extras.callbacks import LogCallback
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
+from .callbacks import LogCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -24,8 +41,8 @@ logger = get_logger(__name__)


def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    callbacks.append(LogCallback(training_args.output_dir))

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@@ -84,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
        safe_serialization=(not model_args.export_legacy_format),
    )

+    if finetuning_args.stage == "rm":
+        if model_args.adapter_name_or_path is not None:
+            vhead_path = model_args.adapter_name_or_path[-1]
+        else:
+            vhead_path = model_args.model_name_or_path
+
+        if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
+            shutil.copy(
+                os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
+                os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
+            )
+            logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+        elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
+            shutil.copy(
+                os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
+                os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
+            )
+            logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+
    try:
        tokenizer.padding_side = "left"  # restore padding side
        tokenizer.init_kwargs["padding_side"] = "left"
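
The new export branch carries the reward model's value head along with the exported weights, preferring the safetensors file. The same fallback, condensed into a loop (an equivalent sketch, not the code as committed):

for name in (V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME):
    src = os.path.join(vhead_path, name)
    if os.path.exists(src):
        shutil.copy(src, os.path.join(model_args.export_dir, name))
        logger.info("Copied valuehead to {}.".format(model_args.export_dir))
        break
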
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
@@ -9,7 +23,7 @@ from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
-from .common import get_save_dir
+from .common import QUANTIZATION_BITS, get_save_dir
from .locales import ALERTS


@@ -62,17 +76,24 @@ class WebChatModel(ChatModel):
            yield error
            return

+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            quantization_bit = int(get("top.quantization_bit"))
+        else:
+            quantization_bit = None
+
        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_path,
            finetuning_type=finetuning_type,
-            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            quantization_bit=quantization_bit,
+            quantization_method=get("top.quantization_method"),
            template=get("top.template"),
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            visual_inputs=get("top.visual_inputs"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            infer_backend=get("infer.infer_backend"),
+            infer_dtype=get("infer.infer_dtype"),
        )

        if checkpoint_path:
@@ -126,16 +147,15 @@ class WebChatModel(ChatModel):
        ):
            response += new_text
            if tools:
-                result = self.engine.template.format_tools.extract(response)
+                result = self.engine.template.extract_tool(response)
            else:
                result = response

-            if isinstance(result, tuple):
-                name, arguments = result
-                arguments = json.loads(arguments)
-                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
-                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
-                bot_text = "```json\n" + tool_call + "\n```"
+            if isinstance(result, list):
+                tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+                tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+                bot_text = "```json\n" + tool_calls + "\n```"
            else:
                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                bot_text = result
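
extract_tool can now yield several calls at once, rendered as a JSON array instead of a single object. A toy example of the new list branch (the tool names and arguments are made up):

import json

result = [("get_weather", '{"city": "Berlin"}'), ("get_time", '{"tz": "UTC"}')]
tool_calls = [{"name": name, "arguments": json.loads(args)} for name, args in result]
print(json.dumps(tool_calls, indent=4, ensure_ascii=False))
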
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from collections import defaultdict
@@ -33,13 +47,19 @@ DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
+QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
+GPTQ_BITS = ["8", "4", "3", "2"]


def get_save_dir(*paths: str) -> os.PathLike:
    r"""
    Gets the path to saved model checkpoints.
    """
-    paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
+    if os.path.sep in paths[-1]:
+        logger.warning("Found complex path, some features may be not available.")
+        return paths[-1]
+
+    paths = (path.replace(" ", "").strip() for path in paths)
    return os.path.join(DEFAULT_SAVE_DIR, *paths)

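
Behavior of the reworked helper, illustrated (assuming DEFAULT_SAVE_DIR == "saves" and a POSIX separator): a last component containing a path separator is returned untouched with a warning, everything else is space-stripped and joined under the save root.

get_save_dir("Llama-3-8B", "lora", "my ckpt")  # -> "saves/Llama-3-8B/lora/myckpt"
get_save_dir("/abs/path/to/ckpt")              # -> "/abs/path/to/ckpt" (warning logged)
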
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Tuple

from ...data import Role

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict

from ...extras.packages import is_gradio_available
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Union

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
-from ..common import get_save_dir
+from ..common import GPTQ_BITS, get_save_dir
from ..locales import ALERTS


@@ -18,7 +32,11 @@ if TYPE_CHECKING:
    from ..engine import Engine


-GPTQ_BITS = ["8", "4", "3", "2"]
+def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
+    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
+        return gr.Dropdown(value="none", interactive=False)
+    else:
+        return gr.Dropdown(interactive=True)


def save_model(
@@ -96,6 +114,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
        export_dir = gr.Textbox()
        export_hub_model_id = gr.Textbox()

+    checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
+    checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
+
    export_btn = gr.Button()
    info_box = gr.Textbox(show_label=False, interactive=False)
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict

from ...extras.packages import is_gradio_available
@@ -18,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
    input_elems = engine.manager.get_base_elems()
    elem_dict = dict()

-    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+    with gr.Row():
+        infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+        infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")

    with gr.Row():
        load_btn = gr.Button()
        unload_btn = gr.Button()

    info_box = gr.Textbox(show_label=False, interactive=False)

-    input_elems.update({infer_backend})
-    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+    input_elems.update({infer_backend, infer_dtype})
+    elem_dict.update(
+        dict(
+            infer_backend=infer_backend,
+            infer_dtype=infer_dtype,
+            load_btn=load_btn,
+            unload_btn=unload_btn,
+            info_box=info_box,
+        )
+    )

    chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
    elem_dict.update(chat_elems)
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict

from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
-from ..utils import can_quantize
+from ..utils import can_quantize, can_quantize_to


if is_gradio_available():
@@ -29,17 +43,23 @@ def create_top() -> Dict[str, "Component"]:

    with gr.Accordion(open=False) as advanced_tab:
        with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
-            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
+            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
+            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
+            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
            visual_inputs = gr.Checkbox(scale=1)

-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+        list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+    )
    model_name.input(save_config, inputs=[lang, model_name], queue=False)
    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
-    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
+        list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+    )
    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
+    quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)

    return dict(
        lang=lang,
@@ -49,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
        checkpoint_path=checkpoint_path,
        advanced_tab=advanced_tab,
        quantization_bit=quantization_bit,
+        quantization_method=quantization_method,
        template=template,
        rope_scaling=rope_scaling,
        booster=booster,
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict

from transformers.trainer_utils import SchedulerType
@@ -40,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
            num_train_epochs = gr.Textbox(value="3.0")
            max_grad_norm = gr.Textbox(value="1.0")
            max_samples = gr.Textbox(value="100000")
-            compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+            compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")

    input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
    elem_dict.update(
@@ -152,10 +166,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                create_new_adapter = gr.Checkbox()

            with gr.Row():
-                with gr.Column(scale=1):
-                    use_rslora = gr.Checkbox()
-                    use_dora = gr.Checkbox()
-
+                use_rslora = gr.Checkbox()
+                use_dora = gr.Checkbox()
+                use_pissa = gr.Checkbox()
                lora_target = gr.Textbox(scale=2)
                additional_target = gr.Textbox(scale=2)

@@ -168,6 +181,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                create_new_adapter,
                use_rslora,
                use_dora,
+                use_pissa,
                lora_target,
                additional_target,
            }
@@ -182,6 +196,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
            create_new_adapter=create_new_adapter,
            use_rslora=use_rslora,
            use_dora=use_dora,
+            use_pissa=use_pissa,
            lora_target=lora_target,
            additional_target=additional_target,
        )
@@ -279,7 +294,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
        with gr.Column(scale=1):
            loss_viewer = gr.Plot()

-    input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
+    input_elems.update({output_dir, config_path, ds_stage, ds_offload})
    elem_dict.update(
        dict(
            cmd_preview_btn=cmd_preview_btn,
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
CSS = r"""
.duplicate-button {
    margin: auto !important;

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict

from .chatter import WebChatModel

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os

from ..extras.packages import is_gradio_available
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
LOCALES = {
    "lang": {
        "en": {
@@ -71,15 +85,29 @@ LOCALES = {
    "quantization_bit": {
        "en": {
            "label": "Quantization bit",
-            "info": "Enable 4/8-bit model quantization (QLoRA).",
+            "info": "Enable quantization (QLoRA).",
        },
        "ru": {
            "label": "Уровень квантования",
-            "info": "Включить 4/8-битное квантование модели (QLoRA).",
+            "info": "Включить квантование (QLoRA).",
        },
        "zh": {
            "label": "量化等级",
-            "info": "启用 4/8 比特模型量化(QLoRA)。",
+            "info": "启用量化(QLoRA)。",
        },
    },
+    "quantization_method": {
+        "en": {
+            "label": "Quantization method",
+            "info": "Quantization algorithm to use.",
+        },
+        "ru": {
+            "label": "Метод квантования",
+            "info": "Алгоритм квантования, который следует использовать.",
+        },
+        "zh": {
+            "label": "量化方法",
+            "info": "使用的量化算法。",
+        },
+    },
    "template": {
@@ -732,6 +760,20 @@ LOCALES = {
            "info": "使用权重分解的 LoRA。",
        },
    },
+    "use_pissa": {
+        "en": {
+            "label": "Use PiSSA",
+            "info": "Use PiSSA method.",
+        },
+        "ru": {
+            "label": "используйте PiSSA",
+            "info": "Используйте метод PiSSA.",
+        },
+        "zh": {
+            "label": "使用 PiSSA",
+            "info": "使用 PiSSA 方法。",
+        },
+    },
    "lora_target": {
        "en": {
            "label": "LoRA modules (optional)",
@@ -1192,6 +1234,17 @@ LOCALES = {
            "label": "推理引擎",
        },
    },
+    "infer_dtype": {
+        "en": {
+            "label": "Inference data type",
+        },
+        "ru": {
+            "label": "Тип данных для вывода",
+        },
+        "zh": {
+            "label": "推理数据类型",
+        },
+    },
    "load_btn": {
        "en": {
            "value": "Load model",
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple


@@ -57,6 +71,7 @@ class Manager:
            self._id_to_elem["top.finetuning_type"],
            self._id_to_elem["top.checkpoint_path"],
            self._id_to_elem["top.quantization_bit"],
+            self._id_to_elem["top.quantization_method"],
            self._id_to_elem["top.template"],
            self._id_to_elem["top.rope_scaling"],
            self._id_to_elem["top.booster"],
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
@@ -8,9 +22,9 @@ from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
-from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd


if is_gradio_available():
@@ -38,7 +52,7 @@ class Runner:
    def set_abort(self) -> None:
        self.aborted = True
        if self.trainer is not None:
-            abort_leaf_process(self.trainer.pid)
+            abort_process(self.trainer.pid)

    def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -90,6 +104,11 @@ class Runner:
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
        user_config = load_config()

+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            quantization_bit = int(get("top.quantization_bit"))
+        else:
+            quantization_bit = None
+
        args = dict(
            stage=TRAINING_STAGES[get("train.training_stage")],
            do_train=True,
@@ -97,7 +116,8 @@ class Runner:
            cache_dir=user_config.get("cache_dir", None),
            preprocessing_num_workers=16,
            finetuning_type=finetuning_type,
-            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            quantization_bit=quantization_bit,
+            quantization_method=get("top.quantization_method"),
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -160,6 +180,8 @@ class Runner:
            args["create_new_adapter"] = get("train.create_new_adapter")
            args["use_rslora"] = get("train.use_rslora")
            args["use_dora"] = get("train.use_dora")
+            args["pissa_init"] = get("train.use_pissa")
+            args["pissa_convert"] = get("train.use_pissa")
            args["lora_target"] = get("train.lora_target") or "all"
            args["additional_target"] = get("train.additional_target") or None

@@ -219,13 +241,19 @@ class Runner:
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
        user_config = load_config()

+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            quantization_bit = int(get("top.quantization_bit"))
+        else:
+            quantization_bit = None
+
        args = dict(
            stage="sft",
            model_name_or_path=get("top.model_path"),
            cache_dir=user_config.get("cache_dir", None),
            preprocessing_num_workers=16,
            finetuning_type=finetuning_type,
-            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            quantization_bit=quantization_bit,
+            quantization_method=get("top.quantization_method"),
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -283,6 +311,7 @@ class Runner:

        env = deepcopy(os.environ)
        env["LLAMABOARD_ENABLED"] = "1"
+        env["LLAMABOARD_WORKDIR"] = args["output_dir"]
        if args.get("deepspeed", None) is not None:
            env["FORCE_TORCHRUN"] = "1"

@@ -291,7 +320,7 @@ class Runner:

    def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
        config_dict = {}
-        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
        for elem, value in data.items():
            elem_id = self.manager.get_id_by_elem(elem)
            if elem_id not in skip_ids:
Some files were not shown because too many files have changed in this diff.