Jelajahi Sumber

Upload profile pictures from the web UI

oobabooga 3 tahun lalu
induk
melakukan
f71531186b
1 mengubah file dengan 28 tambahan dan 6 penghapusan
  1. 28 6
      server.py

+ 28 - 6
server.py

@@ -5,9 +5,11 @@ import glob
 import torch
 import argparse
 import json
+import io
 import sys
 from sys import exit
 from pathlib import Path
+from PIL import Image
 import copy
 import gradio as gr
 import warnings
@@ -504,19 +506,27 @@ if args.chat or args.cai_chat:
         else:
             return name2, context, history['visible']
 
-    def upload_character(file, name1, name2):
-        file = file.decode('utf-8')
-        data = json.loads(file)
+    def upload_character(json_file, img, name1, name2):
+        json_file = json_file.decode('utf-8')
+        data = json.loads(json_file)
         outfile_name = data["char_name"]
         i = 1
         while Path(f'characters/{outfile_name}.json').exists():
             outfile_name = f'{data["char_name"]}_{i:03d}'
             i += 1
         with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
-            f.write(file)
+            f.write(json_file)
+        if img is not None:
+            img = Image.open(io.BytesIO(img)).convert('RGB')
+            img.save(Path(f'characters/{outfile_name}.jpg'))
         print(f'New character saved to "characters/{outfile_name}.json".')
         return outfile_name
 
+    def upload_your_profile_picture(img):
+        img = Image.open(io.BytesIO(img)).convert('RGB')
+        img.save(Path(f'img_me.jpg'))
+        print(f'Profile picture saved to "img_me.jpg"')
+
     suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
     with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
         if args.cai_chat:
@@ -559,7 +569,16 @@ if args.chat or args.cai_chat:
                 download = gr.File()
                 save_btn = gr.Button(value="Click me")
             with gr.Tab('Upload character'):
-                upload_char = gr.File(type='binary')
+                with gr.Row():
+                    with gr.Column():
+                        gr.Markdown('1. Select the JSON file')
+                        upload_char = gr.File(type='binary')
+                    with gr.Column():
+                        gr.Markdown('2. Select your character\'s profile picture (optional)')
+                        upload_img = gr.File(type='binary')
+                upload_btn = gr.Button(value="Submit")
+            with gr.Tab('Upload your profile picture'):
+                upload_img_me = gr.File(type='binary')
 
         input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider]
         if args.cai_chat:
@@ -579,12 +598,15 @@ if args.chat or args.cai_chat:
         save_btn.click(save_history, inputs=[], outputs=[download])
         character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
         upload.upload(upload_history, [upload, name1, name2], [])
-        upload_char.upload(upload_character, [upload_char, name1, name2], [character_menu])
+        upload_btn.click(upload_character, [upload_char, upload_img, name1, name2], [character_menu])
+        upload_img_me.upload(upload_your_profile_picture, [upload_img_me], [])
 
         if args.cai_chat:
             upload.upload(redraw_html, [name1, name2], [display1])
+            upload_img_me.upload(redraw_html, [name1, name2], [display1])
         else:
             upload.upload(lambda : history['visible'], [], [display1])
+            upload_img_me.upload(lambda : history['visible'], [], [display1])
 
 elif args.notebook:
     with gr.Blocks(css=css, analytics_enabled=False) as interface: