@@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
continue
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
- if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
+ if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
return True
return False