|
|
@@ -0,0 +1,31 @@
|
|
|
+'''
|
|
|
+This code was copied from
|
|
|
+
|
|
|
+https://github.com/PygmalionAI/gradio-ui/
|
|
|
+
|
|
|
+'''
|
|
|
+
|
|
|
+import torch
|
|
|
+import transformers
|
|
|
+
|
|
|
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|
|
+
|
|
|
+ def __init__(self, sentinel_token_ids: torch.LongTensor,
|
|
|
+ starting_idx: int):
|
|
|
+ transformers.StoppingCriteria.__init__(self)
|
|
|
+ self.sentinel_token_ids = sentinel_token_ids
|
|
|
+ self.starting_idx = starting_idx
|
|
|
+
|
|
|
+ def __call__(self, input_ids: torch.LongTensor,
|
|
|
+ _scores: torch.FloatTensor) -> bool:
|
|
|
+ for sample in input_ids:
|
|
|
+ trimmed_sample = sample[self.starting_idx:]
|
|
|
+ # Can't unfold, output is still too tiny. Skip.
|
|
|
+ if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ for window in trimmed_sample.unfold(
|
|
|
+ 0, self.sentinel_token_ids.shape[-1], 1):
|
|
|
+ if torch.all(torch.eq(self.sentinel_token_ids, window)):
|
|
|
+ return True
|
|
|
+ return False
|