fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
This commit is contained in:
@@ -29,7 +29,7 @@ class Template:
|
||||
encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id]
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
prompt_ids = prompt_ids + encoded_pairs[-1][0]
|
||||
return prompt_ids, encoded_pairs[-1][1]
|
||||
|
||||
@@ -73,6 +73,11 @@ class Template:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
"""
|
||||
if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
|
||||
bos_token_id = [tokenizer.bos_token_id]
|
||||
else:
|
||||
bos_token_id = []
|
||||
eos_token_id = [tokenizer.eos_token_id] # eos token is required
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
@@ -81,7 +86,7 @@ class Template:
|
||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((prefix_ids + query_ids, resp_ids))
|
||||
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
|
||||
return encoded_pairs
|
||||
|
||||
def _convert_inputs_to_ids(
|
||||
|
||||
Reference in New Issue
Block a user