Merge the data part into the text stream
Former-commit-id: 7ee20286d9bcc2d5378bfd6bb02cd3648396d873
This commit is contained in:
@@ -29,7 +29,10 @@ def get_processor(model_path):
|
||||
def apply_lora(base_model_path, model_path, lora_path):
|
||||
print(f"Loading the base model from {base_model_path}")
|
||||
base_model = AutoModelForVision2Seq.from_pretrained(
|
||||
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cuda",
|
||||
base_model_path,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="cuda",
|
||||
)
|
||||
processor = get_processor(base_model_path)
|
||||
tokenizer = processor.tokenizer
|
||||
@@ -60,11 +63,14 @@ def main(
|
||||
if not os.path.exists(model_path) or do_merge:
|
||||
apply_lora(base_model_path, model_path, lora_model_path)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda"
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="cuda",
|
||||
)
|
||||
processor = get_processor(model_path)
|
||||
raw_datasets = load_dataset(dataset_name)
|
||||
train_dataset = raw_datasets['train']
|
||||
train_dataset = raw_datasets["train"]
|
||||
examples = train_dataset.select(range(3))
|
||||
texts = []
|
||||
images = []
|
||||
@@ -81,5 +87,5 @@ def main(
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
||||
Reference in New Issue
Block a user