stopping_criteria.py 1.0 KB

12345678910111213141516171819202122232425262728293031
  1. '''
  2. This code was copied from
  3. https://github.com/PygmalionAI/gradio-ui/
  4. '''
  5. import torch
  6. import transformers
  class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
      """Stop generation once a sentinel token sequence appears in the output.

      Only positions ``starting_idx`` and later of each sample are searched.
      Generation stops for the whole batch as soon as *any* sample contains
      the sentinel.
      """

      def __init__(self, sentinel_token_ids: torch.LongTensor,
                   starting_idx: int):
          """
          Args:
              sentinel_token_ids: token ids to search for; the sequence runs
                  along the last dimension. NOTE(review): presumably 1-D or
                  shape (1, k) as returned by a tokenizer -- confirm with callers.
              starting_idx: index into each sample where the search begins
                  (presumably the prompt length -- confirm with callers).
          """
          super().__init__()
          self.sentinel_token_ids = sentinel_token_ids
          self.starting_idx = starting_idx

      def __call__(self, input_ids: torch.LongTensor,
                   _scores: torch.FloatTensor, **kwargs) -> bool:
          """Return True if any sample contains the sentinel after starting_idx.

          Args:
              input_ids: generated token ids, iterated one sample (row) at a time.
              _scores: unused; present to match the StoppingCriteria signature.
              **kwargs: extra arguments the caller may pass; ignored.

          Returns:
              True as soon as one sample contains a window equal to the
              sentinel, otherwise False.
          """
          # torch.eq raises when its operands live on different devices
          # (e.g. a CPU sentinel vs. CUDA input_ids), so align them first.
          sentinel = self.sentinel_token_ids.to(input_ids.device)
          window_len = sentinel.shape[-1]
          # Singleton dims inserted so each window broadcasts against the whole
          # sentinel tensor, whatever its rank.
          singleton_dims = (1,) * (sentinel.dim() - 1)
          for sample in input_ids:
              trimmed_sample = sample[self.starting_idx:]
              # Can't unfold, output is still too tiny. Skip.
              if trimmed_sample.shape[-1] < window_len:
                  continue
              # (num_windows, window_len): every contiguous run of window_len ids.
              windows = trimmed_sample.unfold(0, window_len, 1)
              windows = windows.reshape(windows.shape[0], *singleton_dims,
                                        window_len)
              # A window matches when it equals every element of the sentinel
              # (same rule as torch.all(torch.eq(sentinel, window))). Comparing
              # all windows at once avoids one host/device sync per window.
              matches = torch.eq(sentinel, windows).flatten(1).all(dim=1)
              if bool(matches.any()):
                  return True
          return False