Procházet zdrojové kódy

Update to use new llamacpp API

Thomas Antony před 2 roky
rodič
revize
7fa5d96c22
1 změnil soubory, kde provedl 14 přidání a 27 odebrání
  1. 14 27
      modules/llamacpp_model.py

+ 14 - 27
modules/llamacpp_model.py

@@ -8,16 +8,16 @@ import llamacpp
 
 class LlamaCppTokenizer:
     """A thin wrapper over the llamacpp tokenizer"""
-    def __init__(self, model: llamacpp.PyLLAMA):
+    def __init__(self, model: llamacpp.LlamaInference):
         self._tokenizer = model.get_tokenizer()
         self.eos_token_id = 2
         self.bos_token_id = 0
 
     @classmethod
-    def from_model(cls, model: llamacpp.PyLLAMA):
+    def from_model(cls, model: llamacpp.LlamaInference):
         return cls(model)
 
-    def encode(self, prompt):
+    def encode(self, prompt: str):
         return self._tokenizer.tokenize(prompt)
 
     def decode(self, ids):
@@ -30,21 +30,10 @@ class LlamaCppModel:
 
     @classmethod
     def from_pretrained(self, path):
-        params = llamacpp.gpt_params(
-            str(path),  # model
-            2048,  # ctx_size
-            200,  # n_predict
-            40,  # top_k
-            0.95,  # top_p
-            0.80,  # temp
-            1.30,  # repeat_penalty
-            -1,  # seed
-            8,  # threads
-            64,  # repeat_last_n
-            8,  # batch_size
-        )
-
-        _model = llamacpp.PyLLAMA(params)
+        params = llamacpp.InferenceParams()
+        params.path_model = str(path)
+
+        _model = llamacpp.LlamaInference(params)
 
         result = self()
         result.model = _model
@@ -63,22 +52,20 @@ class LlamaCppModel:
         # params.repeat_last_n = repeat_last_n
 
         # model.params = params
-        if not self.initialized:
-            self.model.add_bos()
-
+        self.model.add_bos()
         self.model.update_input(context)
-        if not self.initialized:
-            self.model.prepare_context()
-            self.initialized = True
 
         output = ""
         is_end_of_text = False
         ctr = 0
-        while not self.model.is_finished() and ctr < num_tokens and not is_end_of_text:
+        while ctr < num_tokens and not is_end_of_text:
             if self.model.has_unconsumed_input():
-                self.model.ingest_all_pending_input(False)
+                self.model.ingest_all_pending_input()
             else:
-                text, is_end_of_text = self.model.infer_text()
+                self.model.eval()
+                token = self.model.sample()
+                text = self.model.token_to_str(token)
+                is_end_of_text = token == self.model.token_eos()
                 if callback:
                     callback(text)
                 output += text