10 Commits

Author SHA1 Message Date
hiyouga
ba618947e7 release v0.1.5
Former-commit-id: d619e76bc4098c29a7fdc05f5a71208bd1079c9f
2023-08-02 16:10:31 +08:00
hoshi-hiyouga
f81041b502 Merge pull request #307 from GitYCC/feature/fix-llama2-prompt-template
[feature] Fix template of Llama2 to match the official template

Former-commit-id: a750b1f1ed16e20233df4d2f1c20507122919f5a
2023-08-02 15:51:28 +08:00
YC Chen
f2533a2800 [fix] Remove useless code
Former-commit-id: 077e1556112913e4eeef47e581055183b39d5404
2023-08-02 14:35:35 +08:00
YC Chen
bb5b4a7f26 [feature] Fix template of Llama2 to match the official template
Former-commit-id: 1a98d45aefd95eea3768fb93e5a9da257ec61181
2023-08-02 14:10:15 +08:00
hiyouga
20bff87021 fix bug in preprocessing
Former-commit-id: 94952894576dfc4b42118162aec9aa35c3503c40
2023-08-02 01:10:28 +08:00
hiyouga
722b954800 update readme
Former-commit-id: 5154a04869be8c47e591351565b7842339fb99e4
2023-08-01 18:48:27 +08:00
hiyouga
19256086c7 fix #296
Former-commit-id: 69e9ed9b96a7cfb3d3b43ec5ddd01aa0bfd9b784
2023-08-01 18:43:53 +08:00
hiyouga
250fecfcd4 Fix #294
Former-commit-id: 09762d9849655f5e6c71b9472d55b42489dd944b
2023-08-01 18:13:03 +08:00
hiyouga
cb4d1d5ebb restore from git lfs
Former-commit-id: 0c734a37113b773ae7c0bc8b8d1af39b15bc0fb2
2023-08-01 16:33:25 +08:00
hiyouga
d7d557fb2e Update .gitattributes
Former-commit-id: 92e68f9f30c2fc91ae1b40865bc5c2d94899ba22
2023-08-01 16:28:54 +08:00
20 changed files with 84 additions and 37 deletions

.gitattributes vendored
View File

@@ -1,3 +1,2 @@
 # Auto detect text files and perform LF normalization
 * text=auto
-*.json filter=lfs diff=lfs merge=lfs -text

View File

@@ -145,10 +145,10 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ### All-in-one Web UI
 ```bash
-python src/train_web.py
+CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
-Currently the web UI only supports training on a single GPU.
+Currently the web UI only supports training on **a single GPU**.
 ### (Continually) Pre-Training
@@ -196,6 +196,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
+Remember to specify `--lora_target W_pack` if you are using Baichuan models.
 ### Reward Model Training
 ```bash
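
As an editor's aside, the same single-GPU restriction can also be expressed from inside Python rather than through the shell prefix; a minimal hedged sketch (the device index is only an example):

```python
# Equivalent of launching with `CUDA_VISIBLE_DEVICES=0`: expose a single GPU.
# The variable must be set before torch (or any CUDA library) initializes CUDA.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # make only GPU 0 visible to this process

import torch

print(torch.cuda.device_count())  # prints 1 on a multi-GPU machine
```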

View File

@@ -145,10 +145,10 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ### 浏览器一键微调/测试
 ```bash
-python src/train_web.py
+CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
-目前网页 UI 仅支持单卡训练。
+目前网页 UI 仅支持**单卡训练**
 ### 二次预训练
@@ -196,6 +196,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
+使用 Baichuan 模型时请指定 `--lora_target W_pack` 参数。
 ### 奖励模型训练
 ```bash

View File

@@ -0,0 +1 @@
+3779ddbc040543ab1834ef216c983d6fcc06cc9a

View File

@@ -0,0 +1 @@
+fc9a6a3458caca2af8dafc6181773fe10c6d8657

View File

@@ -0,0 +1 @@
+25508714b7879a1e5a6764ba7f979a980f549f1a

View File

@@ -0,0 +1 @@
+7cb6a7d11455bddc3d495750a2392683d775b184

View File

@@ -0,0 +1 @@
+f5cb08305ff5dc9c17a09809c54c8c8834aadc70

View File

@@ -0,0 +1 @@
+aee47b7b443496e37808d7f34ef10403ff99bcc3

View File

@@ -0,0 +1 @@
+274079ea921762be356de85b18f13fa60b7ba8cb

View File

@@ -0,0 +1 @@
+57fd080be5bffe4153fe3ee26a175e3d56da30f3

View File

@@ -0,0 +1 @@
+f967a4f6d04a11308a15524aa9a846a19a8d1e83

View File

@@ -0,0 +1 @@
+0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8

View File

@@ -0,0 +1 @@
+38c89869c6aeca2a3af9ea1e09afe460f9b46810

View File

@@ -1,4 +1,4 @@
 from llmtuner.chat import ChatModel
-__version__ = "0.1.4"
+__version__ = "0.1.5"

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
 from llmtuner.tuner import load_model_and_tokenizer
@@ -21,15 +21,7 @@ class ChatModel:
         generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model
-            from accelerate.utils import infer_auto_device_map, get_balanced_memory
-            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
-            self.model = dispatch_model(self.model, device_map)
-        else:
-            self.model = self.model.cuda()
+        self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args

View File

@@ -2,7 +2,7 @@ import os
 import hashlib
 from typing import TYPE_CHECKING, List, Optional
-from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
 from llmtuner.extras.logging import get_logger
@@ -93,7 +93,11 @@ def get_dataset(
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
         if dataset_attr.source_prefix: # add prefix
-            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
+            features = None
+            if data_args.streaming:
+                features = dataset.features
+                features["prefix"] = Value(dtype="string", id=None)
+            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
         all_datasets.append(dataset)
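
To make the intent of this hunk easier to follow, here is a self-contained sketch of the same idea with toy data; the example dataset and prefix string are invented, and only `Value` and `map(..., features=...)` come from the change above. The point is that when a constant `prefix` column is added through `map`, the output schema is declared explicitly so that a streaming dataset keeps complete feature information.

```python
# Illustrative sketch with toy data; mirrors the prefix-column change above.
from datasets import Dataset, Features, Value

dataset = Dataset.from_dict({"prompt": ["Hello"], "response": ["Hi there"]})
source_prefix = "You are a helpful assistant."  # hypothetical prefix

# Declare the mapped schema, including the new "prefix" column, so a streaming
# IterableDataset would not lose its feature information after map().
features = Features({**dataset.features, "prefix": Value(dtype="string")})
dataset = dataset.map(lambda _: {"prefix": source_prefix}, features=features)

print(dataset.features["prefix"])  # Value(dtype='string', id=None)
print(dataset[0]["prefix"])        # You are a helpful assistant.
```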

View File

@@ -18,15 +18,15 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> "Dataset":
-    column_names = list(dataset.column_names or [])
+    column_names = list(dataset.column_names)
     template = get_template(data_args.template)
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
             query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-            history = history if "history" in examples and examples["history"][i] else []
-            prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
+            history = examples["history"][i] if "history" in examples else None
+            prefix = examples["prefix"][i] if "prefix" in examples else None
             yield query, response, history, prefix
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -143,15 +143,19 @@ def preprocess_dataset(
     if stage == "pt":
         dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_pretrain_dataset
+        print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
         dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
         preprocess_function = preprocess_supervised_dataset
+        print_function = print_supervised_dataset_example
     elif stage == "rm":
         dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_function = preprocess_pairwise_dataset
+        print_function = print_pairwise_dataset_example
     else:
         dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_unsupervised_dataset
+        print_function = print_unsupervised_dataset_example
     with training_args.main_process_first(desc="dataset map pre-processing"):
         kwargs = {}
@@ -172,13 +176,5 @@ def preprocess_dataset(
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-        if stage == "pt":
-            print_unsupervised_dataset_example(next(iter(dataset)))
-        elif stage == "sft":
-            print_supervised_dataset_example(next(iter(dataset)))
-        elif stage == "rm":
-            print_pairwise_dataset_example(next(iter(dataset)))
-        elif stage == "ppo":
-            print_unsupervised_dataset_example(next(iter(dataset)))
+        print_function(next(iter(dataset)))
         return dataset
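
A tiny sketch (toy data, stand-in helper, both invented here) of the pattern this refactor settles on: pick the `print_*` helper once per stage and call it on the first example via `next(iter(...))`, which works for both map-style and streaming datasets.

```python
# Toy illustration of print_function(next(iter(dataset))); the dataset and the
# helper below are stand-ins, not the project's own objects.
from datasets import Dataset

dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})

def print_unsupervised_example(example):
    # stand-in for print_unsupervised_dataset_example
    print("input_ids:", example["input_ids"])

print_function = print_unsupervised_example
print_function(next(iter(dataset)))  # peeks at the first preprocessed example
```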

View File

@@ -117,3 +117,25 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
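
For reference, a self-contained usage sketch of the same balanced-memory dispatch with a stock transformers model; the model name and dtype are illustrative, and the snippet assumes the chosen model class defines `_no_split_modules`, as the helper above requires.

```python
# Hedged sketch: spread a Hugging Face model across all visible GPUs with the
# same accelerate calls used by the helper above; fall back to a single device.
import torch
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)

if torch.cuda.device_count() > 1:
    kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
    max_memory = get_balanced_memory(model, **kwargs)  # balance memory across GPUs
    model.tie_weights()                                # tie weights before mapping
    device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
    model = dispatch_model(model, device_map)
elif torch.cuda.is_available():
    model = model.cuda()
```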

View File

@@ -46,18 +46,37 @@ class Template:
         prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
         history = history + [(query, "")]
-        convs = [
-            [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
-            for turn_idx, (query_i, resp_i) in enumerate(history)
+        return [
+            [(self.sep if i else prefix) + self.prompt.format(query=q), r]
+            for i, (q, r) in enumerate(history)
         ]
-        return convs
+@dataclass
+class Llama2Template(Template):
+    def _format_example(
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
+        prefix = prefix or self.prefix # use prefix if provided
+        prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
+        history = history if (history and self.use_history) else []
+        history = history + [(query, "")]
+        return [
+            [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
+            for i, (q, r) in enumerate(history)
+        ]
 templates: Dict[str, Template] = {}
 def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    templates[name] = Template(
+    template_class = Llama2Template if name == "llama2" else Template
+    templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
         sep=sep,
@@ -111,8 +130,8 @@ register_template(
     "If a question does not make any sense, or is not factually coherent, "
     "explain why instead of answering something not correct. "
     "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt=" [INST] {query} [/INST] ",
-    sep="",
+    prompt="[INST] {query} [/INST] ",
+    sep="<s>",
     use_history=True
 )
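
To show what the fixed template produces, here is a standalone sketch that replays the `Llama2Template` logic from the diff with a shortened, made-up system prompt and dialogue; the `prompt` and `sep` values are the ones registered above.

```python
# Minimal reproduction of the corrected Llama-2 prompt construction.
prefix = "<<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"  # shortened example
prompt = "[INST] {query} [/INST] "
sep = "<s>"

def format_example(query, history=None):
    history = (history or []) + [(query, "")]
    return [
        [(sep if i else "") + prompt.format(query=(q if i else prefix + q)), r]
        for i, (q, r) in enumerate(history)
    ]

turns = format_example("And 3 + 3?", history=[("What is 2 + 2?", "4.")])
print("".join(q + r for q, r in turns))
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# What is 2 + 2? [/INST] 4.<s>[INST] And 3 + 3? [/INST]
```

Later turns are joined by the `sep="<s>"` registered above; the very first `<s>` is normally contributed by the tokenizer's BOS token rather than by the template itself.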