Procházet zdrojové kódy

Move bot_picture.py inside the extension

oobabooga před 2 roky
rodič
revize
91f5852245
3 změnil soubory, kde provedl 15 přidání a 18 odebrání
  1. 10 4
      extensions/send_pictures/script.py
  2. 0 10
      modules/bot_picture.py
  3. 5 4
      modules/extensions.py

+ 10 - 4
extensions/send_pictures/script.py

@@ -2,13 +2,11 @@ import base64
 from io import BytesIO
 
 import gradio as gr
+import torch
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 import modules.chat as chat
 import modules.shared as shared
-from modules.bot_picture import caption_image
-
-params = {
-}
 
 # If 'state' is True, will hijack the next chat generation with
 # custom input text
@@ -17,6 +15,14 @@ input_hijack = {
     'value': ["", ""]
 }
 
+processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
+
+def caption_image(raw_image):
+    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
+    out = model.generate(**inputs, max_new_tokens=100)
+    return processor.decode(out[0], skip_special_tokens=True)
+
 def generate_chat_picture(picture, name1, name2):
     text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
     buffer = BytesIO()

+ 0 - 10
modules/bot_picture.py

@@ -1,10 +0,0 @@
-import torch
-from transformers import BlipForConditionalGeneration, BlipProcessor
-
-processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
-
-def caption_image(raw_image):
-    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
-    out = model.generate(**inputs, max_new_tokens=100)
-    return processor.decode(out[0], skip_special_tokens=True)

+ 5 - 4
modules/extensions.py

@@ -33,10 +33,11 @@ def apply_extensions(text, typ):
 def create_extensions_block():
     # Updating the default values
     for extension, name in iterator():
-        for param in extension.params:
-            _id = f"{name}-{param}"
-            if _id in shared.settings:
-                extension.params[param] = shared.settings[_id]
+        if hasattr(extension, 'params'):
+            for param in extension.params:
+                _id = f"{name}-{param}"
+                if _id in shared.settings:
+                    extension.params[param] = shared.settings[_id]
 
     # Creating the extension ui elements
     for extension, name in iterator():