oobabooga 2 years ago
Parent commit: e91f4bc25a
3 changed files with 35 additions and 15 deletions
  1. modules/RWKV.py (+20 -0)
  2. modules/models.py (+3 -2)
  3. modules/text_generation.py (+12 -13)

+ 20 - 0
modules/RWKV.py

@@ -2,6 +2,7 @@ import os
 from pathlib import Path
 
 import numpy as np
+from tokenizers import Tokenizer
 
 import modules.shared as shared
 
@@ -43,3 +44,22 @@ class RWKVModel:
         )
 
         return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+
+class RWKVTokenizer:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(cls, path):
+        tokenizer_path = path / "20B_tokenizer.json"
+        tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
+
+        result = cls()
+        result.tokenizer = tokenizer
+        return result
+
+    def encode(self, prompt):
+        return self.tokenizer.encode(prompt).ids
+
+    def decode(self, ids):
+        return self.tokenizer.decode(ids)
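
The new RWKVTokenizer is a thin wrapper around the Hugging Face tokenizers library: Tokenizer.from_file loads the serialized 20B_tokenizer.json (the GPT-NeoX-20B vocabulary that Pile-trained RWKV checkpoints use), encode returns the raw token ids, and decode inverts them. A minimal round-trip sketch, assuming 20B_tokenizer.json sits in models/ as the load_model change below expects:

    from pathlib import Path
    from modules.RWKV import RWKVTokenizer

    tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
    ids = tokenizer.encode("Hello world")    # list[int] of token ids
    text = tokenizer.decode(ids)             # back to a string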

+ 3 - 2
modules/models.py

@@ -79,11 +79,12 @@ def load_model(model_name):
 
     # RWKV model (not on HuggingFace)
     elif shared.is_RWKV:
-        from modules.RWKV import RWKVModel
+        from modules.RWKV import RWKVModel, RWKVTokenizer
 
         model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+        tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
 
-        return model, None
+        return model, tokenizer
 
     # Custom
     else:
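
With this change the RWKV branch returns a real (model, tokenizer) pair instead of (model, None), so downstream code can tokenize RWKV prompts the same way as any other backend. A hypothetical call-site sketch (the model name is illustrative; shared.is_RWKV is derived from it at load time):

    from modules.models import load_model
    import modules.shared as shared

    shared.model, shared.tokenizer = load_model('rwkv-4-pile-169m')
    print(type(shared.tokenizer))    # RWKVTokenizer rather than None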

+ 12 - 13
modules/text_generation.py

@@ -21,21 +21,20 @@ def get_max_prompt_length(tokens):
     return max_length
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-
-    # These models do not have explicit tokenizers for now, so
-    # we return an estimate for the number of tokens
     if shared.is_RWKV:
-        return np.zeros((1, len(prompt)//4))
-
-    input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
-    if shared.args.cpu:
-        return input_ids
-    elif shared.args.flexgen:
-        return input_ids.numpy()
-    elif shared.args.deepspeed:
-        return input_ids.to(device=local_rank)
+        input_ids = shared.tokenizer.encode(str(prompt))
+        input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        return input_ids
     else:
-        return input_ids.cuda()
+        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+        if shared.args.cpu:
+            return input_ids
+        elif shared.args.flexgen:
+            return input_ids.numpy()
+        elif shared.args.deepspeed:
+            return input_ids.to(device=local_rank)
+        else:
+            return input_ids.cuda()
 
 def decode(output_ids):
     reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
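
The encode() rewrite above replaces the old RWKV length estimate, np.zeros((1, len(prompt)//4)), with real token ids in the same (1, n) batch shape the Hugging Face path produces (note the RWKV branch applies no max_length truncation). A standalone sketch of what that branch now computes, using an illustrative tokenizer path:

    import numpy as np
    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("models/20B_tokenizer.json")
    ids = tok.encode("Hello world").ids               # real ids, not an estimate
    input_ids = np.array(ids).reshape(1, len(ids))    # shape (1, n): a batch of one prompt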