|
|
@@ -7,6 +7,7 @@ import numpy as np
|
|
|
from tokenizers import Tokenizer
|
|
|
|
|
|
import modules.shared as shared
|
|
|
+from modules.callbacks import Iteratorize
|
|
|
|
|
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
|
|
|
|
|
@@ -73,38 +74,3 @@ class RWKVTokenizer:
|
|
|
|
|
|
def decode(self, ids):
|
|
|
return self.tokenizer.decode(ids)
|
|
|
-
|
|
|
-class Iteratorize:
|
|
|
-
|
|
|
- """
|
|
|
- Transforms a function that takes a callback
|
|
|
- into a lazy iterator (generator).
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, func, kwargs={}, callback=None):
|
|
|
- self.mfunc=func
|
|
|
- self.c_callback=callback
|
|
|
- self.q = Queue(maxsize=1)
|
|
|
- self.sentinel = object()
|
|
|
- self.kwargs = kwargs
|
|
|
-
|
|
|
- def _callback(val):
|
|
|
- self.q.put(val)
|
|
|
-
|
|
|
- def gentask():
|
|
|
- ret = self.mfunc(callback=_callback, **self.kwargs)
|
|
|
- self.q.put(self.sentinel)
|
|
|
- if self.c_callback:
|
|
|
- self.c_callback(ret)
|
|
|
-
|
|
|
- Thread(target=gentask).start()
|
|
|
-
|
|
|
- def __iter__(self):
|
|
|
- return self
|
|
|
-
|
|
|
- def __next__(self):
|
|
|
- obj = self.q.get(True,None)
|
|
|
- if obj is self.sentinel:
|
|
|
- raise StopIteration
|
|
|
- else:
|
|
|
- return obj
|