# RWKV.py — loader, generation, and tokenizer wrappers for the RWKV language model.
import os
from pathlib import Path
from queue import Queue  # NOTE(review): unused in this file — confirm before removing
from threading import Thread  # NOTE(review): unused in this file — confirm before removing
import numpy as np
from tokenizers import Tokenizer
import modules.shared as shared
from modules.callbacks import Iteratorize

# Compact numpy printing when inspecting logits/state arrays during debugging.
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# These env vars configure the `rwkv` package and are presumably read at
# import time — hence they are set BEFORE the `rwkv` imports below.
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
  14. class RWKVModel:
  15. def __init__(self):
  16. pass
  17. @classmethod
  18. def from_pretrained(self, path, dtype="fp16", device="cuda"):
  19. tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
  20. if shared.args.rwkv_strategy is None:
  21. model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
  22. else:
  23. model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
  24. pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
  25. result = self()
  26. result.pipeline = pipeline
  27. return result
  28. def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
  29. args = PIPELINE_ARGS(
  30. temperature = temperature,
  31. top_p = top_p,
  32. top_k = top_k,
  33. alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
  34. alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
  35. token_ban = token_ban, # ban the generation of some tokens
  36. token_stop = token_stop
  37. )
  38. return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
  39. def generate_with_streaming(self, **kwargs):
  40. iterable = Iteratorize(self.generate, kwargs, callback=None)
  41. reply = kwargs['context']
  42. for token in iterable:
  43. reply += token
  44. yield reply
  45. class RWKVTokenizer:
  46. def __init__(self):
  47. pass
  48. @classmethod
  49. def from_pretrained(self, path):
  50. tokenizer_path = path / "20B_tokenizer.json"
  51. tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
  52. result = self()
  53. result.tokenizer = tokenizer
  54. return result
  55. def encode(self, prompt):
  56. return self.tokenizer.encode(prompt).ids
  57. def decode(self, ids):
  58. return self.tokenizer.decode(ids)