瀏覽代碼

Fix the early stopping callback #559

oobabooga 2 年之前
父節點
當前提交
8c8e8b4450
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      modules/callbacks.py

+ 1 - 1
modules/callbacks.py

@@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
                 if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                 if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                     continue
                     continue
                 for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
                 for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
-                    if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
+                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                         return True
                         return True
         return False
         return False