oobabooga 2 лет назад
Родитель
Сommit
8c8e8b4450
1 измененных файлов с 1 добавлено и 1 удалено
  1. 1 1
      modules/callbacks.py

+ 1 - 1
modules/callbacks.py

@@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
                 if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                     continue
                 for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
-                    if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
+                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                         return True
         return False