Explorar o código

Add arg for bfloat16

81300 %!s(int64=3) %!d(string=hai) anos
pai
achega
a6f4760772
Modificáronse 1 ficheiros con 15 adicións e 5 borrados
  1. 15 5
      server.py

+ 15 - 5
server.py

@@ -37,6 +37,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='Directory to use for DeepSpeed ZeRO-3 NVME offloading.')
+parser.add_argument('--bf16', action='store_true', help='Instantiate the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--local_rank', type=int, default=0, help='Optional argument for DeepSpeed distributed setups.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
@@ -92,14 +93,20 @@ if args.deepspeed:
 
     # DeepSpeed configration
     # https://huggingface.co/docs/transformers/main_classes/deepspeed
+    if args.bf16:
+        ds_fp16 = False
+        ds_bf16 = True
+    else:
+        ds_fp16 = True
+        ds_bf16 = False
     train_batch_size = 1 * world_size
     if args.nvme_offload_dir:
         ds_config = {
             "fp16": {
-                "enabled": True,
+                "enabled": ds_fp16,
             },
             "bf16": {
-                "enabled": False,
+                "enabled": ds_bf16,
             },
             "zero_optimization": {
                 "stage": 3,
@@ -135,10 +142,10 @@ if args.deepspeed:
     else:
         ds_config = {
             "fp16": {
-                "enabled": True,
+                "enabled": ds_fp16,
             },
             "bf16": {
-                "enabled": False,
+                "enabled": ds_bf16,
             },
             "zero_optimization": {
                 "stage": 3,
@@ -178,7 +185,10 @@ def load_model(model_name):
 
     # DeepSpeed ZeRO-3
     elif args.deepspeed:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"))
+        if args.bf16:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.float16)
         model = deepspeed.initialize(model=model,
                                      config_params=ds_config,
                                      model_parameters=None,