oobabooga пре 2 година
родитељ
комит
67ee7bead7
1 измењених фајлова са 11 додато и 10 уклоњено
  1. 11 10
      modules/RWKV.py

+ 11 - 10
modules/RWKV.py

@@ -1,6 +1,13 @@
-import os, time, types, torch
+import os
+import time
+import types
 from pathlib import Path
+
 import numpy as np
+import torch
+
+import modules.shared as shared
+
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 os.environ['RWKV_JIT_ON'] = '1'
@@ -10,17 +17,11 @@ import repositories.ChatRWKV.v2.rwkv as rwkv
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 
-def load_RWKV_model(path):
-    os.system("ls")
-    model = RWKV(model=path.as_posix(), strategy='cuda fp16')
 
-    out, state = model.forward([187, 510, 1563, 310, 247], None)   # use 20B_tokenizer.json
-    print(out.detach().cpu().numpy())                   # get logits
-    out, state = model.forward([187, 510], None)
-    out, state = model.forward([1563], state)           # RNN has state (use deepcopy if you want to clone it)
-    out, state = model.forward([310, 247], state)
-    print(out.detach().cpu().numpy())                   # same result as above
+def load_RWKV_model(path):
+    print(f'strategy={"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}')
 
+    model = RWKV(model=path.as_posix(), strategy=f'{"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}')
     pipeline = PIPELINE(model, Path("repositories/ChatRWKV/20B_tokenizer.json").as_posix())
 
     return pipeline