import os
from pathlib import Path

import llamacpp

import modules.shared as shared
from modules.callbacks import Iteratorize

class LlamaCppTokenizer:
    """A thin wrapper over the llamacpp tokenizer"""

    def __init__(self, model: llamacpp.PyLLAMA):
        self._tokenizer = model.get_tokenizer()
        # Hardcoded LLaMA special-token ids: 2 is end-of-sequence, 0 is beginning-of-sequence
        self.eos_token_id = 2
        self.bos_token_id = 0

    @classmethod
    def from_model(cls, model: llamacpp.PyLLAMA):
        return cls(model)

    def encode(self, prompt):
        return self._tokenizer.tokenize(prompt)

    def decode(self, ids):
        return self._tokenizer.detokenize(ids)
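
# A minimal round-trip sketch of the tokenizer wrapper (not part of the original
# file): it assumes a loaded llamacpp.PyLLAMA instance named `_model`, such as
# the one built in LlamaCppModel.from_pretrained below.
#
#     tokenizer = LlamaCppTokenizer.from_model(_model)
#     ids = tokenizer.encode("Hello, world")
#     text = tokenizer.decode(ids)  # should recover "Hello, world"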

class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(cls, path):
        # llamacpp.gpt_params takes positional arguments; the comments name each one
        params = llamacpp.gpt_params(
            str(path),  # model
            2048,       # ctx_size
            200,        # n_predict
            40,         # top_k
            0.95,       # top_p
            0.80,       # temp
            1.30,       # repeat_penalty
            -1,         # seed
            8,          # threads
            64,         # repeat_last_n
            8,          # batch_size
        )
        _model = llamacpp.PyLLAMA(params)

        result = cls()
        result.model = _model

        tokenizer = LlamaCppTokenizer.from_model(_model)
        return result, tokenizer

    # TODO: Allow passing in params for each inference
    def generate(self, context="", num_tokens=10, callback=None):
        # params = self.params
        # params.n_predict = token_count
        # params.top_p = top_p
        # params.top_k = top_k
        # params.temp = temperature
        # params.repeat_penalty = repetition_penalty
        # params.repeat_last_n = repeat_last_n
        # model.params = params

        if not self.initialized:
            self.model.add_bos()

        self.model.update_input(context)

        if not self.initialized:
            self.model.prepare_context()
            self.initialized = True

        output = ""
        is_end_of_text = False
        ctr = 0
        while not self.model.is_finished() and ctr < num_tokens and not is_end_of_text:
            if self.model.has_unconsumed_input():
                # Feed any remaining prompt tokens before sampling new ones
                self.model.ingest_all_pending_input(False)
            else:
                text, is_end_of_text = self.model.infer_text()
                if callback:
                    callback(text)
                output += text
                ctr += 1

        return output

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = kwargs['context']
            for token in generator:
                reply += token
                yield reply
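

if __name__ == "__main__":
    # A hedged smoke-test sketch, not part of the original module. It assumes the
    # text-generation-webui environment (for modules.callbacks.Iteratorize) and a
    # local ggml-format LLaMA weights file; the path below is a placeholder.
    model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")

    # Blocking call: returns the whole completion at once.
    print(model.generate(context="Hello, my name is", num_tokens=16))

    # Streaming call: yields the growing reply after each inferred chunk.
    for reply in model.generate_with_streaming(context="Hello, my name is", num_tokens=16):
        print(reply)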