oobabooga 3 年之前
父节点
当前提交
0b2a6b2819
共有 1 个文件被更改,包括 31 次插入0 次删除
  1. 31 0
      modules/stopping_criteria.py

+ 31 - 0
modules/stopping_criteria.py

@@ -0,0 +1,31 @@
+'''
+This code was copied from
+
+https://github.com/PygmalionAI/gradio-ui/
+
+'''
+
+import torch
+import transformers
+
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
+    def __init__(self, sentinel_token_ids: torch.LongTensor,
+                 starting_idx: int):
+        transformers.StoppingCriteria.__init__(self)
+        self.sentinel_token_ids = sentinel_token_ids
+        self.starting_idx = starting_idx
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 _scores: torch.FloatTensor) -> bool:
+        for sample in input_ids:
+            trimmed_sample = sample[self.starting_idx:]
+            # Can't unfold, output is still too tiny. Skip.
+            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+                continue
+
+            for window in trimmed_sample.unfold(
+                    0, self.sentinel_token_ids.shape[-1], 1):
+                if torch.all(torch.eq(self.sentinel_token_ids, window)):
+                    return True
+        return False