llamacpp_model.py

import llamacpp

from modules.callbacks import Iteratorize


class LlamaCppTokenizer:
    """A thin wrapper over the llamacpp tokenizer"""

    def __init__(self, model: llamacpp.LlamaInference):
        self._tokenizer = model.get_tokenizer()
        self.eos_token_id = 2
        self.bos_token_id = 0

    @classmethod
    def from_model(cls, model: llamacpp.LlamaInference):
        return cls(model)

    def encode(self, prompt: str):
        return self._tokenizer.tokenize(prompt)

    def decode(self, ids):
        return self._tokenizer.detokenize(ids)
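
# Usage sketch: round-tripping text through the tokenizer wrapper. This assumes
# `model` is an already-loaded llamacpp.LlamaInference, as constructed in
# LlamaCppModel.from_pretrained below; it uses only the calls wrapped above.
#
#     tokenizer = LlamaCppTokenizer.from_model(model)
#     ids = tokenizer.encode("Hello, world")
#     text = tokenizer.decode(ids)  # should recover the original string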


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(cls, path):
        params = llamacpp.InferenceParams()
        params.path_model = str(path)

        _model = llamacpp.LlamaInference(params)

        result = cls()
        result.model = _model

        tokenizer = LlamaCppTokenizer.from_model(_model)
        return result, tokenizer
    # TODO: Allow passing in params for each inference
    def generate(self, context="", num_tokens=10, callback=None):
        # params = self.params
        # params.n_predict = token_count
        # params.top_p = top_p
        # params.top_k = top_k
        # params.temp = temperature
        # params.repeat_penalty = repetition_penalty
        # params.repeat_last_n = repeat_last_n
        # model.params = params

        self.model.add_bos()
        self.model.update_input(context)

        output = ""
        is_end_of_text = False
        ctr = 0
        while ctr < num_tokens and not is_end_of_text:
            if self.model.has_unconsumed_input():
                # Feed the remaining prompt before sampling any new tokens
                self.model.ingest_all_pending_input()
            else:
                self.model.eval()
                token = self.model.sample()
                text = self.model.token_to_str(token)
                is_end_of_text = token == self.model.token_eos()
                if callback:
                    callback(text)
                output += text
                ctr += 1

        return output
    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = kwargs['context']
            for token in generator:
                reply += token
                yield reply
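

# Minimal smoke-test sketch, guarded so that importing this module has no side
# effects. The weights path is hypothetical (any ggml-format LLaMA file would
# do), and only the llamacpp calls already exercised above are used.
if __name__ == "__main__":
    model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")  # hypothetical path

    # One-shot generation: returns the accumulated text for up to 16 tokens.
    print(model.generate(context="Hello, my name is", num_tokens=16))

    # Streaming: each yielded value is the prompt plus all tokens sampled so far.
    for reply in model.generate_with_streaming(context="Hello, my name is", num_tokens=16):
        print(reply)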