merge data part to the text stream

Former-commit-id: 7ee20286d9bcc2d5378bfd6bb02cd3648396d873
This commit is contained in:
BUAADreamer
2024-04-25 19:19:59 +08:00
parent 00e2a272ef
commit 3c792174db
13 changed files with 802 additions and 284 deletions

View File

@@ -29,7 +29,10 @@ def get_processor(model_path):
def apply_lora(base_model_path, model_path, lora_path):
print(f"Loading the base model from {base_model_path}")
base_model = AutoModelForVision2Seq.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda",
base_model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="cuda",
)
processor = get_processor(base_model_path)
tokenizer = processor.tokenizer
@@ -60,11 +63,14 @@ def main(
if not os.path.exists(model_path) or do_merge:
apply_lora(base_model_path, model_path, lora_model_path)
model = AutoModelForVision2Seq.from_pretrained(
model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda"
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
device_map="cuda",
)
processor = get_processor(model_path)
raw_datasets = load_dataset(dataset_name)
train_dataset = raw_datasets['train']
train_dataset = raw_datasets["train"]
examples = train_dataset.select(range(3))
texts = []
images = []
@@ -81,5 +87,5 @@ def main(
print(res)
if __name__ == '__main__':
if __name__ == "__main__":
fire.Fire(main)