Upgrading all other files to be able to use cpu/mps as well as cuda. Various other minor changes, e.g. changing max_iterations to num_iterations in the sft script for consistency in naming.
This commit is contained in:
@@ -34,6 +34,7 @@ device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
@@ -116,6 +117,7 @@ def mid_data_generator(split):
|
||||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
@@ -127,6 +129,10 @@ def mid_data_generator(split):
|
||||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
@@ -135,7 +141,10 @@ def mid_data_generator(split):
|
||||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if split == "train":
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
|
||||
Reference in New Issue
Block a user