commit 8516054e4d
parent d1a8cd67d2
Author: hiyouga
Date:   2025-01-03 10:50:32 +00:00

    update scripts

    Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848

5 changed files with 19 additions and 13 deletions

@@ -41,7 +41,7 @@ def calculate_lr(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,  # i.e. maximum input length during training
+    cutoff_len: int = 2048,  # i.e. maximum input length during training
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate
     packing: bool = False,
 ):
@@ -59,6 +59,7 @@ def calculate_lr(
             template=template,
             cutoff_len=cutoff_len,
             packing=packing,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
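These dict entries become HF-style data arguments, and preprocessing_num_workers is the knob that datasets ultimately receives as num_proc when the tokenizer is mapped over the corpus, so the dummy preprocessing pass now runs 16 worker processes. A minimal sketch of the underlying call, with tokenize_fn as a hypothetical per-batch tokenization function:

    # Illustrative only: preprocessing_num_workers surfaces as datasets' num_proc.
    tokenized = dataset.map(
        tokenize_fn,  # hypothetical tokenization function
        batched=True,
        num_proc=16,  # 16 parallel preprocessing workers
        remove_columns=dataset.column_names,
    )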
@@ -79,7 +80,7 @@ def calculate_lr(
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0
-    for batch in tqdm(dataloader):
+    for batch in tqdm(dataloader, desc="Collecting valid tokens"):
         valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
         total_tokens += torch.numel(batch["labels"])
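This loop measures how much of each batch is actually supervised: labels equal to IGNORE_INDEX (-100, the Transformers masking convention) mark prompt or padding positions that the loss skips. A toy reproduction of the same counting:

    import torch

    IGNORE_INDEX = -100  # masked label value, per the Transformers convention

    # Two sequences of length 4; prompt tokens are masked with IGNORE_INDEX.
    labels = torch.tensor([[-100, -100, 52, 13], [-100, 7, 8, 9]])
    valid_tokens = torch.sum(labels != IGNORE_INDEX).item()  # 5
    total_tokens = torch.numel(labels)                       # 8
    print(valid_tokens / total_tokens)                       # 0.625 supervised-token ratio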

@@ -63,7 +63,7 @@ def calculate_ppl(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,
+    cutoff_len: int = 2048,
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
@@ -82,6 +82,7 @@ def calculate_ppl(
             cutoff_len=cutoff_len,
             max_samples=max_samples,
             train_on_prompt=train_on_prompt,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
@@ -111,7 +112,7 @@ def calculate_ppl(
     perplexities = []
     batch: Dict[str, "torch.Tensor"]
     with torch.no_grad():
-        for batch in tqdm(dataloader):
+        for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
             shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
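The slice drops the last position so that the logits at step t are scored against the token at step t+1. A self-contained sketch of per-sequence perplexity computed from that alignment (the masking and reduction details here are illustrative, not copied from the script):

    import torch

    def sequence_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seq_len, vocab); labels: (batch, seq_len), -100 = ignored position
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        criterion = torch.nn.CrossEntropyLoss(reduction="none")  # -100 targets contribute zero loss
        token_loss = criterion(shift_logits.transpose(1, 2), shift_labels)  # (batch, seq_len - 1)
        loss_mask = (shift_labels != -100).float()
        mean_nll = (token_loss * loss_mask).sum(-1) / loss_mask.sum(-1)
        return torch.exp(mean_nll)  # one perplexity per sequence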

@@ -42,6 +42,7 @@ def length_cdf(
             dataset_dir=dataset_dir,
             template=template,
             cutoff_len=1_000_000,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
@@ -52,7 +53,7 @@ def length_cdf(
     trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
     total_num = len(trainset)
     length_dict = defaultdict(int)
-    for sample in tqdm(trainset["input_ids"]):
+    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
         length_dict[len(sample) // interval * interval] += 1
     length_tuples = list(length_dict.items())
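Bucketing by len(sample) // interval * interval floors each sample's token count to its interval, so sorting the buckets and accumulating counts yields the cumulative distribution the script reports. A sketch of that final step, reusing length_dict, total_num, and interval from the hunk above:

    # Illustrative: accumulate sorted bucket counts into a cumulative distribution.
    count = 0
    for length, num in sorted(length_dict.items()):
        count += num
        print(f"{count / total_num * 100:6.2f}% of samples have length <= {length + interval - 1}")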