Explorar o código

Fix the early stopping callback #559

oobabooga %!s(int64=2) %!d(string=hai) anos
pai
achega
8c8e8b4450
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      modules/callbacks.py

+ 1 - 1
modules/callbacks.py

@@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
                 if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                     continue
                 for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
-                    if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
+                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                         return True
         return False