llamacpp_model.py

from pathlib import Path

import llamacpp

import modules.shared as shared
from modules.callbacks import Iteratorize


class LlamaCppTokenizer:
    """A thin wrapper over the llamacpp tokenizer"""

    def __init__(self, model: llamacpp.LlamaInference):
        self._tokenizer = model.get_tokenizer()
        # Special-token ids are hard-coded here rather than read from the model
        self.eos_token_id = 2
        self.bos_token_id = 0

    @classmethod
    def from_model(cls, model: llamacpp.LlamaInference):
        return cls(model)

    def encode(self, prompt: str):
        return self._tokenizer.tokenize(prompt)

    def decode(self, ids):
        return self._tokenizer.detokenize(ids)
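
# Usage sketch (an assumption for illustration: `model` stands for an
# already-constructed llamacpp.LlamaInference instance):
#
#     tokenizer = LlamaCppTokenizer.from_model(model)
#     ids = tokenizer.encode("Hello, world")
#     text = tokenizer.decode(ids)  # recovers the prompt, modulo tokenizer whitespace rules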


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(cls, path):
        params = llamacpp.InferenceParams()
        params.path_model = str(path)
        _model = llamacpp.LlamaInference(params)

        result = cls()
        result.model = _model
        result.params = params

        tokenizer = LlamaCppTokenizer.from_model(_model)
        return result, tokenizer

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        params = self.params
        params.n_predict = token_count
        params.top_p = top_p
        params.top_k = top_k
        params.temp = temperature
        params.repeat_penalty = repetition_penalty
        # params.repeat_last_n = repeat_last_n
        # model.params = params

        self.model.add_bos()
        self.model.update_input(context)

        output = ""
        is_end_of_text = False
        ctr = 0
        while ctr < token_count and not is_end_of_text:
            if self.model.has_unconsumed_input():
                # The prompt must be fully ingested before sampling can begin
                self.model.ingest_all_pending_input()
            else:
                self.model.eval()
                token = self.model.sample()
                text = self.model.token_to_str(token)
                output += text  # accumulate the decoded tokens into the return value
                is_end_of_text = token == self.model.token_eos()
                if callback:
                    callback(text)
                ctr += 1

        return output

    def generate_with_streaming(self, **kwargs):
        # Iteratorize wraps the callback-based generate() in a generator;
        # each iteration yields the reply accumulated so far.
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
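

# Minimal usage sketch. Assumptions: a GGML model file exists at the placeholder
# path below, and generation settings are illustrative defaults only.
if __name__ == "__main__":
    model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")  # placeholder path

    # One-shot generation: collect the full reply as a single string
    print(model.generate(context="Q: What is the capital of France?\nA:", token_count=16))

    # Streaming generation: each iteration yields the reply accumulated so far
    for reply in model.generate_with_streaming(context="Once upon a time", token_count=16):
        print(reply)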