Compare commits

10 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | ba618947e7 |  |
|  | f81041b502 |  |
|  | f2533a2800 |  |
|  | bb5b4a7f26 |  |
|  | 20bff87021 |  |
|  | 722b954800 |  |
|  | 19256086c7 |  |
|  | 250fecfcd4 |  |
|  | cb4d1d5ebb |  |
|  | d7d557fb2e |  |
.gitattributes (vendored, 1 line changed)

```diff
@@ -1,3 +1,2 @@
 # Auto detect text files and perform LF normalization
 * text=auto
-*.json filter=lfs diff=lfs merge=lfs -text
```
README.md

````diff
@@ -145,10 +145,10 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ### All-in-one Web UI
 
 ```bash
-python src/train_web.py
+CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
-Currently the web UI only supports training on a single GPU.
+Currently the web UI only supports training on **a single GPU**.
 
 ### (Continually) Pre-Training
 
@@ -196,6 +196,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
+Remember to specify `--lora_target W_pack` if you are using Baichuan models.
+
 ### Reward Model Training
 
 ```bash
````
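Both README edits make the same point: the web UI trains on one GPU, so the launch command now pins a device up front. A rough illustration (not repo code) of why the variable has to be set before anything touches CUDA:

```python
# Illustration only: CUDA_VISIBLE_DEVICES must be set before torch initializes
# CUDA, which is why the README bakes it into the launch command itself.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only the first GPU

import torch  # imported after the variable is set

print(torch.cuda.device_count())  # prints 1 even on a multi-GPU machine
```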
README_zh.md (same edits in the Chinese README, given here in English translation)

````diff
@@ -145,10 +145,10 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ### One-click fine-tuning/testing in your browser
 
 ```bash
-python src/train_web.py
+CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
-Currently the web UI only supports single-GPU training.
+Currently the web UI only supports **single-GPU training**.
 
 ### Continual pre-training
 
@@ -196,6 +196,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
+Please specify the `--lora_target W_pack` argument when using Baichuan models.
+
 ### Reward model training
 
 ```bash
````
Eleven dataset files under `data/` were replaced by one-line placeholder files (each diff is a new file with `@@ -0,0 +1 @@`):

| New file | Stored git id |
|---|---|
| data/alpaca_data_en_52k.json.REMOVED.git-id | 3779ddbc040543ab1834ef216c983d6fcc06cc9a |
| data/alpaca_data_zh_51k.json.REMOVED.git-id | fc9a6a3458caca2af8dafc6181773fe10c6d8657 |
| data/alpaca_gpt4_data_en.json.REMOVED.git-id | 25508714b7879a1e5a6764ba7f979a980f549f1a |
| data/alpaca_gpt4_data_zh.json.REMOVED.git-id | 7cb6a7d11455bddc3d495750a2392683d775b184 |
| data/comparison_gpt4_data_en.json.REMOVED.git-id | f5cb08305ff5dc9c17a09809c54c8c8834aadc70 |
| data/comparison_gpt4_data_zh.json.REMOVED.git-id | aee47b7b443496e37808d7f34ef10403ff99bcc3 |
| data/oaast_rm.json.REMOVED.git-id | 274079ea921762be356de85b18f13fa60b7ba8cb |
| data/oaast_sft.json.REMOVED.git-id | 57fd080be5bffe4153fe3ee26a175e3d56da30f3 |
| data/refgpt_zh_50k_p1.json.REMOVED.git-id | f967a4f6d04a11308a15524aa9a846a19a8d1e83 |
| data/refgpt_zh_50k_p2.json.REMOVED.git-id | 0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8 |
| data/sharegpt_zh_27k.json.REMOVED.git-id | 38c89869c6aeca2a3af9ea1e09afe460f9b46810 |
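Together with the `.gitattributes` change above (dropping the `*.json` LFS filter), each placeholder records the git id of the JSON it replaces. A minimal sketch, assuming the placeholders sit under `data/` as listed, that prints each stored id:

```python
# Minimal sketch: each *.REMOVED.git-id placeholder is a single line holding the
# 40-hex git id of the dataset file that was removed from the repository.
from pathlib import Path

for placeholder in sorted(Path("data").glob("*.REMOVED.git-id")):
    blob_id = placeholder.read_text().strip()
    print(f"{placeholder.name}: {blob_id}")
```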
src/llmtuner/__init__.py (version bump)

```diff
@@ -1,4 +1,4 @@
 from llmtuner.chat import ChatModel
 
 
-__version__ = "0.1.4"
+__version__ = "0.1.5"
```
src/llmtuner/chat/stream_chat.py (multi-GPU dispatch moved into a shared helper)

```diff
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 
-from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
 from llmtuner.tuner import load_model_and_tokenizer
 
@@ -21,15 +21,7 @@ class ChatModel:
         generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-
-        if torch.cuda.device_count() > 1:
-            from accelerate import dispatch_model
-            from accelerate.utils import infer_auto_device_map, get_balanced_memory
-            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
-            self.model = dispatch_model(self.model, device_map)
-        else:
-            self.model = self.model.cuda()
-
+        self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
```
src/llmtuner/dsets/loader.py (streaming-safe prefix column)

```diff
@@ -2,7 +2,7 @@ import os
 import hashlib
 from typing import TYPE_CHECKING, List, Optional
 
-from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
 
 from llmtuner.extras.logging import get_logger
 
@@ -93,7 +93,11 @@ def get_dataset(
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
 
         if dataset_attr.source_prefix: # add prefix
-            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
+            features = None
+            if data_args.streaming:
+                features = dataset.features
+                features["prefix"] = Value(dtype="string", id=None)
+            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
 
         all_datasets.append(dataset)
```
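Why the extra `features` plumbing: a streaming `IterableDataset` cannot infer the schema of a column added in `map`, so the updated `Features` must be declared explicitly or later `concatenate_datasets`/`interleave_datasets` calls see inconsistent schemas. A self-contained sketch of the pattern, assuming a recent `datasets` release with `to_iterable_dataset`:

```python
# Self-contained sketch: add a constant "prefix" column to a streaming dataset
# while declaring the new schema up front, mirroring the loader.py change above.
from datasets import Dataset, Value

stream = Dataset.from_dict({"prompt": ["hi", "hello"]}).to_iterable_dataset()
features = stream.features  # known here because we built it from a regular Dataset
features["prefix"] = Value(dtype="string", id=None)
stream = stream.map(lambda _: {"prefix": "You are a helpful assistant."}, features=features)

print(next(iter(stream)))  # -> {'prompt': 'hi', 'prefix': 'You are a helpful assistant.'}
```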
src/llmtuner/dsets/preprocess.py (per-stage dispatch consolidated)

```diff
@@ -18,15 +18,15 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> "Dataset":
-    column_names = list(dataset.column_names or [])
+    column_names = list(dataset.column_names)
     template = get_template(data_args.template)
 
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
             query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-            history = examples["history"][i] if "history" in examples and examples["history"][i] else []
-            prefix = examples["prefix"][i] if "prefix" in examples and examples["prefix"][i] else ""
+            history = examples["history"][i] if "history" in examples else None
+            prefix = examples["prefix"][i] if "prefix" in examples else None
             yield query, response, history, prefix
 
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -143,15 +143,19 @@ def preprocess_dataset(
     if stage == "pt":
         dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_pretrain_dataset
+        print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
         dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
         preprocess_function = preprocess_supervised_dataset
+        print_function = print_supervised_dataset_example
     elif stage == "rm":
         dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_function = preprocess_pairwise_dataset
+        print_function = print_pairwise_dataset_example
     else:
         dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_function = preprocess_unsupervised_dataset
+        print_function = print_unsupervised_dataset_example
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
         kwargs = {}
@@ -172,13 +176,5 @@ def preprocess_dataset(
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
 
-        if stage == "pt":
-            print_unsupervised_dataset_example(next(iter(dataset)))
-        elif stage == "sft":
-            print_supervised_dataset_example(next(iter(dataset)))
-        elif stage == "rm":
-            print_pairwise_dataset_example(next(iter(dataset)))
-        elif stage == "ppo":
-            print_unsupervised_dataset_example(next(iter(dataset)))
-
+        print_function(next(iter(dataset)))
         return dataset
```
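The stage branches now bind a `print_function` alongside the `preprocess_function` and filter out unusable rows up front: empty prompts are dropped everywhere, and reward modeling additionally needs at least two responses to form a preference pair. A toy illustration of those filters, as a rough sketch with made-up data:

```python
# Toy illustration of the new per-stage filters (data invented for the example).
from datasets import Dataset

ds = Dataset.from_dict({
    "prompt": ["hi", "", "rank me"],
    "response": [["a"], ["b"], ["good", "bad"]],
})

sft_ready = ds.filter(lambda ex: ex["prompt"] and ex["response"])          # drops the empty prompt
rm_ready = ds.filter(lambda ex: ex["prompt"] and len(ex["response"]) > 1)  # keeps only the pair

print(len(sft_ready), len(rm_ready))  # 2 1
```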
src/llmtuner/extras/misc.py (new dispatch_model helper)

```diff
@@ -117,3 +117,25 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
+
+
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
```
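A hedged smoke test for the helper, not from the PR: any HF model that defines `_no_split_modules` works (gpt2 does), and a CUDA machine is assumed since the single-GPU branch calls `.cuda()`. `get_balanced_memory` computes a per-GPU `max_memory` budget that spreads the weights evenly, and `infer_auto_device_map` then assigns whole modules to devices without ever splitting the classes named in `_no_split_modules`.

```python
# Hypothetical smoke test: with several GPUs visible the helper builds a
# balanced device map; with one GPU it simply moves the model there.
import torch
from transformers import AutoModelForCausalLM
from llmtuner.extras.misc import dispatch_model

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = dispatch_model(model)
print({name: p.device for name, p in list(model.named_parameters())[:3]})
```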
src/llmtuner/extras/template.py (Llama2Template split out; llama2 prompt and sep fixed)

```diff
@@ -46,18 +46,37 @@ class Template:
         prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
         history = history + [(query, "")]
-        convs = [
-            [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
-            for turn_idx, (query_i, resp_i) in enumerate(history)
+        return [
+            [(self.sep if i else prefix) + self.prompt.format(query=q), r]
+            for i, (q, r) in enumerate(history)
         ]
-        return convs
+
+
+@dataclass
+class Llama2Template(Template):
+
+    def _format_example(
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = ""
+    ) -> List[Tuple[str, str]]:
+        prefix = prefix or self.prefix # use prefix if provided
+        prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
+        history = history if (history and self.use_history) else []
+        history = history + [(query, "")]
+        return [
+            [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
+            for i, (q, r) in enumerate(history)
+        ]
 
 
 templates: Dict[str, Template] = {}
 
 
 def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    templates[name] = Template(
+    template_class = Llama2Template if name == "llama2" else Template
+    templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
         sep=sep,
@@ -111,8 +130,8 @@ register_template(
     "If a question does not make any sense, or is not factually coherent, "
     "explain why instead of answering something not correct. "
     "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt=" [INST] {query} [/INST] ",
-    sep="",
+    prompt="[INST] {query} [/INST] ",
+    sep="<s>",
     use_history=True
 )
```
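To make the effect of the new llama2 values concrete, here is a hand-rolled trace of `Llama2Template._format_example` for a single turn. The values are copied from the registration above; the system prompt is an example, and this is an illustration rather than repo code.

```python
# Hand-rolled single-turn trace of Llama2Template._format_example, using the
# registered llama2 values prompt="[INST] {query} [/INST] " and sep="<s>".
prefix = "<<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"  # example system prompt
history = [("Hello!", "")]  # the current query with an empty response slot

convs = [
    [("<s>" if i else "") + "[INST] {} [/INST] ".format(q if i else prefix + q), r]
    for i, (q, r) in enumerate(history)
]
print(convs[0][0])
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hello! [/INST]
```

The `<<SYS>>` block is injected into the first turn only, and `sep="<s>"` now prepends a BOS token to every later turn instead of nothing.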