update scripts

Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
This commit is contained in:
hiyouga
2025-01-03 10:50:32 +00:00
parent d1a8cd67d2
commit 8516054e4d
5 changed files with 19 additions and 13 deletions

View File

@@ -41,7 +41,7 @@ def calculate_lr(
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
cutoff_len: int = 2048, # i.e. maximum input length during training
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
@@ -59,6 +59,7 @@ def calculate_lr(
template=template,
cutoff_len=cutoff_len,
packing=packing,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@@ -79,7 +80,7 @@ def calculate_lr(
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])