ソースを参照

Create new API

oobabooga 2 年 前
コミット
9c3a585915
2 ファイル変更41 行追加1 行削除
  1. 38 0
      modules/api.py
  2. 3 1
      server.py

+ 38 - 0
modules/api.py

@@ -0,0 +1,38 @@
+import json
+
+import gradio as gr
+
+from modules import shared
+from modules.text_generation import generate_reply
+
+
+def generate_reply_wrapper(string):
+    generate_params = {
+        'do_sample': True,
+        'temperature': 1,
+        'top_p': 1,
+        'typical_p': 1,
+        'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
+        'top_k': 50,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'min_length': 0,
+        'length_penalty': 1,
+        'no_repeat_ngram_size': 0,
+        'early_stopping': False,
+    }
+    params = json.loads(string)
+    for k in params[1]:
+        generate_params[k] = params[1][k]
+    for i in generate_reply(params[0], generate_params):
+        yield i
+
+def create_apis():
+    t1 = gr.Textbox(visible=False)
+    t2 = gr.Textbox(visible=False)
+    dummy = gr.Button(visible=False)
+
+    input_params = [t1]
+    output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
+    dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

+ 3 - 1
server.py

@@ -15,7 +15,7 @@ import gradio as gr
 from PIL import Image
 
 import modules.extensions as extensions_module
-from modules import chat, shared, training, ui
+from modules import chat, shared, training, ui, api
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
@@ -538,6 +538,8 @@ def create_interface():
             else:
                 shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
 
+        api.create_apis()
+
     # Authentication
     auth = None
     if shared.args.gradio_auth_path is not None: