# RWKV.py
import os
from pathlib import Path
from queue import Queue
from threading import Thread

import numpy as np
from tokenizers import Tokenizer

import modules.shared as shared

np.set_printoptions(precision=4, suppress=True, linewidth=200)

# NOTE: these environment variables are read by the `rwkv` package at import
# time, so they must be set BEFORE the `from rwkv...` imports below.
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
  13. class RWKVModel:
  14. def __init__(self):
  15. pass
  16. @classmethod
  17. def from_pretrained(self, path, dtype="fp16", device="cuda"):
  18. tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
  19. if shared.args.rwkv_strategy is None:
  20. model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
  21. else:
  22. model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
  23. pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
  24. result = self()
  25. result.pipeline = pipeline
  26. return result
  27. def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
  28. args = PIPELINE_ARGS(
  29. temperature = temperature,
  30. top_p = top_p,
  31. top_k = top_k,
  32. alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
  33. alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
  34. token_ban = token_ban, # ban the generation of some tokens
  35. token_stop = token_stop
  36. )
  37. return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
  38. def generate_with_streaming(self, **kwargs):
  39. iterable = Iteratorize(self.generate, kwargs, callback=None)
  40. reply = kwargs['context']
  41. for token in iterable:
  42. reply += token
  43. yield reply
  44. class RWKVTokenizer:
  45. def __init__(self):
  46. pass
  47. @classmethod
  48. def from_pretrained(self, path):
  49. tokenizer_path = path / "20B_tokenizer.json"
  50. tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
  51. result = self()
  52. result.tokenizer = tokenizer
  53. return result
  54. def encode(self, prompt):
  55. return self.tokenizer.encode(prompt).ids
  56. def decode(self, ids):
  57. return self.tokenizer.decode(ids)
  58. class Iteratorize:
  59. """
  60. Transforms a function that takes a callback
  61. into a lazy iterator (generator).
  62. """
  63. def __init__(self, func, kwargs={}, callback=None):
  64. self.mfunc=func
  65. self.c_callback=callback
  66. self.q = Queue(maxsize=1)
  67. self.sentinel = object()
  68. self.kwargs = kwargs
  69. def _callback(val):
  70. self.q.put(val)
  71. def gentask():
  72. ret = self.mfunc(callback=_callback, **self.kwargs)
  73. self.q.put(self.sentinel)
  74. if self.c_callback:
  75. self.c_callback(ret)
  76. Thread(target=gentask).start()
  77. def __iter__(self):
  78. return self
  79. def __next__(self):
  80. obj = self.q.get(True,None)
  81. if obj is self.sentinel:
  82. raise StopIteration
  83. else:
  84. return obj