RWKV.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. from pathlib import Path
  3. import numpy as np
  4. from tokenizers import Tokenizer
  5. import modules.shared as shared
  6. from modules.callbacks import Iteratorize
  7. np.set_printoptions(precision=4, suppress=True, linewidth=200)
  8. os.environ['RWKV_JIT_ON'] = '1'
  9. os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
  10. from rwkv.model import RWKV
  11. from rwkv.utils import PIPELINE, PIPELINE_ARGS
  12. class RWKVModel:
  13. def __init__(self):
  14. pass
  15. @classmethod
  16. def from_pretrained(self, path, dtype="fp16", device="cuda"):
  17. tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
  18. if shared.args.rwkv_strategy is None:
  19. model = RWKV(model=str(path), strategy=f'{device} {dtype}')
  20. else:
  21. model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
  22. pipeline = PIPELINE(model, str(tokenizer_path))
  23. result = self()
  24. result.pipeline = pipeline
  25. return result
  26. def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
  27. args = PIPELINE_ARGS(
  28. temperature=temperature,
  29. top_p=top_p,
  30. top_k=top_k,
  31. alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
  32. alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
  33. token_ban=token_ban, # ban the generation of some tokens
  34. token_stop=token_stop
  35. )
  36. return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
  37. def generate_with_streaming(self, **kwargs):
  38. with Iteratorize(self.generate, kwargs, callback=None) as generator:
  39. reply = ''
  40. for token in generator:
  41. reply += token
  42. yield reply
  43. class RWKVTokenizer:
  44. def __init__(self):
  45. pass
  46. @classmethod
  47. def from_pretrained(self, path):
  48. tokenizer_path = path / "20B_tokenizer.json"
  49. tokenizer = Tokenizer.from_file(str(tokenizer_path))
  50. result = self()
  51. result.tokenizer = tokenizer
  52. return result
  53. def encode(self, prompt):
  54. return self.tokenizer.encode(prompt).ids
  55. def decode(self, ids):
  56. return self.tokenizer.decode(ids)