Browse Source

Improve the imports

oobabooga 2 năm trước cách đây
mục cha
commit
7224343a70

+ 3 - 4
convert-to-flexgen.py

@@ -3,6 +3,7 @@
 Converts a transformers model to a format compatible with flexgen.
 
 '''
+
 import argparse
 import os
 from pathlib import Path
@@ -10,9 +11,8 @@ from pathlib import Path
 import numpy as np
 import torch
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
- 
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 args = parser.parse_args()
@@ -31,7 +31,6 @@ def disable_torch_init():
     torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-
 def restore_torch_init():
     """Rollback the change made by disable_torch_init."""
     import torch

+ 3 - 3
convert-to-safetensors.py

@@ -10,13 +10,13 @@ Based on the original script by 81300:
 https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
 
 '''
+
 import argparse
 from pathlib import Path
 
 import torch
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
- 
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')

+ 1 - 2
modules/bot_picture.py

@@ -1,6 +1,5 @@
 import torch
-from transformers import BlipForConditionalGeneration
-from transformers import BlipProcessor
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")

+ 3 - 4
modules/chat.py

@@ -7,13 +7,12 @@ from datetime import datetime
 from io import BytesIO
 from pathlib import Path
 
+from PIL import Image
+
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
-from modules.text_generation import encode
-from modules.text_generation import generate_reply
-from modules.text_generation import get_max_prompt_length
-from PIL import Image
+from modules.text_generation import encode, generate_reply, get_max_prompt_length
 
 if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
     import modules.bot_picture as bot_picture

+ 1 - 2
modules/extensions.py

@@ -1,6 +1,5 @@
-import modules.shared as shared
-
 import extensions
+import modules.shared as shared
 
 extension_state = {}
 available_extensions = []

+ 1 - 0
modules/html_generator.py

@@ -3,6 +3,7 @@
 This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 
 '''
+
 import base64
 import os
 import re

+ 9 - 5
modules/models.py

@@ -4,23 +4,27 @@ import time
 import zipfile
 from pathlib import Path
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import modules.shared as shared
 
 transformers.logging.set_verbosity_error()
 
 local_rank = None
 
 if shared.args.flexgen:
-    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
+    from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
+                                  TorchDevice, TorchDisk, TorchMixedDevice,
+                                  get_opt_config)
 
 if shared.args.deepspeed:
     import deepspeed
-    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
+    from transformers.deepspeed import (HfDeepSpeedConfig,
+                                        is_deepspeed_zero3_enabled)
+
     from modules.deepspeed_parameters import generate_ds_config
 
     # Distributed setup

+ 2 - 0
modules/stopping_criteria.py

@@ -4,9 +4,11 @@ This code was copied from
 https://github.com/PygmalionAI/gradio-ui/
 
 '''
+
 import torch
 import transformers
 
+
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 
     def __init__(self, sentinel_token_ids: torch.LongTensor,

+ 5 - 4
modules/text_generation.py

@@ -1,16 +1,17 @@
 import re
 import time
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
+from tqdm import tqdm
+
+import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import generate_4chan_html
-from modules.html_generator import generate_basic_html
+from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
-from tqdm import tqdm
+
 
 def get_max_prompt_length(tokens):
     max_length = 2048-tokens

+ 2 - 5
server.py

@@ -14,12 +14,9 @@ import modules.chat as chat
 import modules.extensions as extensions_module
 import modules.shared as shared
 import modules.ui as ui
-from modules.extensions import extension_state
-from modules.extensions import load_extensions
-from modules.extensions import update_extensions_parameters
+from modules.extensions import extension_state, load_extensions, update_extensions_parameters
 from modules.html_generator import generate_chat_html
-from modules.models import load_model
-from modules.models import load_soft_prompt
+from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
 
 if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: