瀏覽代碼

Add --rwkv-strategy parameter

oobabooga 2 年之前
父節點
當前提交
a2a3e8f797
共有 2 個文件被更改,包括 5 次插入1 次删除
  1. 4 1
      modules/RWKV.py
  2. 1 0
      modules/shared.py

+ 4 - 1
modules/RWKV.py

@@ -25,7 +25,10 @@ class RWKVModel:
     def from_pretrained(self, path, dtype="fp16", device="cuda"):
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
 
-        model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
+        if shared.args.rwkv_strategy is None:
+            model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
+        else:
+            model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
         pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
 
         result = self()

+ 1 - 0
modules/shared.py

@@ -63,6 +63,7 @@ parser.add_argument("--compress-weight", action="store_true", help="FlexGen: act
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
 parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')