update scripts
Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
This commit is contained in:
@@ -41,7 +41,7 @@ def calculate_lr(
|
||||
dataset: str = "alpaca_en_demo",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
cutoff_len: int = 2048, # i.e. maximum input length during training
|
||||
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
|
||||
packing: bool = False,
|
||||
):
|
||||
@@ -59,6 +59,7 @@ def calculate_lr(
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
packing=packing,
|
||||
preprocessing_num_workers=16,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
do_train=True,
|
||||
@@ -79,7 +80,7 @@ def calculate_lr(
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
for batch in tqdm(dataloader):
|
||||
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
|
||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||
total_tokens += torch.numel(batch["labels"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user