فهرست منبع

Add RWKVModel class

oobabooga 2 سال پیش
والد
کامیت
659bb76722
2 فایلهای تغییر یافته به همراه 18 افزوده شده و 7 حذف شده
  1. 14 5
      modules/RWKV.py
  2. 4 2
      modules/models.py

+ 14 - 5
modules/RWKV.py

@@ -16,11 +16,20 @@ os.environ["RWKV_CUDA_ON"] = '0' #  '1' : use CUDA kernel for seq mode (much fas
 from rwkv.model import RWKV
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 
 
+class RWKVModel:
+    def __init__(self):
+        pass
 
 
-def load_RWKV_model(path):
-    print(f'strategy={"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}')
+    @classmethod
+    def from_pretrained(self, path, dtype="fp16", device="cuda"):
+        tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
 
 
-    model = RWKV(model=path.as_posix(), strategy=f'{"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}')
-    pipeline = PIPELINE(model, Path("models/20B_tokenizer.json").as_posix())
+        model = RWKV(model=path.as_posix(), strategy=f'{device} {dtype}')
+        pipeline = PIPELINE(model, tokenizer_path.as_posix())
 
 
-    return pipeline
+        result = self()
+        result.model = pipeline
+        return result
+
+    def generate(self, context, **kwargs):
+        return self.model.generate(context, **kwargs)

+ 4 - 2
modules/models.py

@@ -79,9 +79,11 @@ def load_model(model_name):
 
 
     # RWKV model (not on HuggingFace)
     # RWKV model (not on HuggingFace)
     elif shared.is_RWKV:
     elif shared.is_RWKV:
-        from modules.RWKV import load_RWKV_model
+        from modules.RWKV import RWKVModel
 
 
-        return load_RWKV_model(Path(f'models/{model_name}')), None
+        model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+
+        return model, None
 
 
     # Custom
     # Custom
     else:
     else: