llamacpp_model.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import llamacpp
  2. from modules.callbacks import Iteratorize
  class LlamaCppTokenizer:
      """A thin wrapper over the llamacpp tokenizer.

      Exposes ``encode``/``decode`` plus ``eos_token_id``/``bos_token_id``.
      """

      def __init__(self, model: "llamacpp.LlamaInference"):
          self._tokenizer = model.get_tokenizer()
          # Read EOS from the model itself so it is the same id that
          # LlamaCppModel.generate() compares against (model.token_eos()).
          self.eos_token_id = model.token_eos()
          # LLaMA's SentencePiece vocabulary is <unk>=0, <s>=1, </s>=2, so the
          # BOS id is 1 (0 would be <unk>).
          # NOTE(review): hard-coded because no BOS accessor on LlamaInference
          # is used in this file; switch to one if the binding provides it.
          self.bos_token_id = 1

      @classmethod
      def from_model(cls, model: "llamacpp.LlamaInference"):
          """Alternate constructor; equivalent to ``cls(model)``."""
          return cls(model)

      def encode(self, prompt: str):
          """Tokenize *prompt* with the wrapped llamacpp tokenizer."""
          return self._tokenizer.tokenize(prompt)

      def decode(self, ids):
          """Detokenize *ids* with the wrapped llamacpp tokenizer."""
          return self._tokenizer.detokenize(ids)
  class LlamaCppModel:
      """Text-generation wrapper around a ``llamacpp.LlamaInference`` instance.

      Build instances with :meth:`from_pretrained`; the plain constructor only
      sets ``initialized`` and leaves ``model``/``params`` unset.
      """

      def __init__(self):
          # NOTE(review): nothing in this class ever sets this to True; confirm
          # whether callers elsewhere read it.
          self.initialized = False

      @classmethod
      def from_pretrained(cls, path):
          """Load the model at *path* and return ``(model_wrapper, tokenizer)``.

          *path* is converted with ``str()``, so a ``pathlib.Path`` works too.
          """
          params = llamacpp.InferenceParams()
          params.path_model = str(path)
          _model = llamacpp.LlamaInference(params)

          result = cls()
          result.model = _model
          # Kept so generate() can write per-call sampling settings into it.
          result.params = params

          tokenizer = LlamaCppTokenizer.from_model(_model)
          return result, tokenizer

      def generate(self, context="", token_count=20, temperature=1, top_p=1,
                   top_k=50, repetition_penalty=1, callback=None):
          """Generate up to *token_count* tokens continuing *context*.

          Args:
              context: prompt text, fed to the model after ``add_bos()``.
              token_count: maximum number of tokens to sample.
              temperature, top_p, top_k, repetition_penalty: sampling
                  settings written onto ``self.params`` (see NOTE below).
              callback: optional callable, invoked with each sampled token's
                  text as soon as it is produced (used for streaming).

          Returns:
              The concatenated text of every sampled token. Sampling stops
              early once the model's EOS token is sampled; that token's text
              is still included in the result and passed to *callback*.
          """
          params = self.params
          params.n_predict = token_count
          params.top_p = top_p
          params.top_k = top_k
          params.temp = temperature
          params.repeat_penalty = repetition_penalty
          # NOTE(review): these fields are written to the params object the
          # model was constructed with, but nothing here hands them back to
          # self.model. Confirm whether LlamaInference reads them live; if it
          # copied them at construction, these settings have no effect.
          # repeat_last_n is likewise not configurable from here.

          # NOTE(review): no reset happens before add_bos()/update_input(); if
          # LlamaInference keeps its context between calls, each call stacks
          # BOS + context onto earlier state — confirm this is intended.
          self.model.add_bos()
          self.model.update_input(context)

          eos_token = self.model.token_eos()
          pieces = []
          generated = 0
          is_end_of_text = False
          while generated < token_count and not is_end_of_text:
              if self.model.has_unconsumed_input():
                  # Feed the pending prompt first; this does not count
                  # towards token_count.
                  self.model.ingest_all_pending_input()
                  continue
              self.model.eval()
              token = self.model.sample()
              text = self.model.token_to_str(token)
              pieces.append(text)
              is_end_of_text = token == eos_token
              if callback is not None:
                  callback(text)
              generated += 1
          return "".join(pieces)

      def generate_with_streaming(self, **kwargs):
          """Yield the progressively growing reply produced by generate().

          *kwargs* are forwarded to generate() through ``Iteratorize``; each
          yielded value is the concatenation of every item received so far.
          """
          # NOTE(review): assumes Iteratorize supplies its own `callback` to
          # generate() and yields each callback payload — confirm in
          # modules.callbacks.
          with Iteratorize(self.generate, kwargs, callback=None) as generator:
              reply = ''
              for token in generator:
                  reply += token
                  yield reply