|
|
@@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|
|
if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
|
|
|
continue
|
|
|
for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
|
|
|
- if torch.all(torch.eq(self.sentinel_token_ids[i], window)):
|
|
|
+ if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
|
|
|
return True
|
|
|
return False
|
|
|
|