Parcourir la source

Add --rwkv-cuda-on parameter, bump rwkv version

oobabooga il y a 2 ans
Parent
commit
153dfeb4dd
3 fichiers modifiés avec 4 ajouts et 3 suppressions
  1. 1 1
      modules/RWKV.py
  2. 2 1
      modules/shared.py
  3. 1 1
      requirements.txt

+ 1 - 1
modules/RWKV.py

@@ -9,7 +9,7 @@ import modules.shared as shared
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 os.environ['RWKV_JIT_ON'] = '1'
-os.environ["RWKV_CUDA_ON"] = '0' #  '1' : use CUDA kernel for seq mode (much faster)
+os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
 
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS

+ 2 - 1
modules/shared.py

@@ -81,7 +81,8 @@ parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, defaul
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
 parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
-parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')

+ 1 - 1
requirements.txt

@@ -3,7 +3,7 @@ bitsandbytes==0.37.0
 flexgen==0.1.7
 gradio==3.18.0
 numpy
-rwkv==0.0.7
+rwkv==0.0.8
 safetensors==0.2.8
 sentencepiece
 git+https://github.com/oobabooga/transformers@llama_push