Create new API
This commit is contained in:
38
modules/api.py
Normal file
38
modules/api.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reply_wrapper(string):
|
||||||
|
generate_params = {
|
||||||
|
'do_sample': True,
|
||||||
|
'temperature': 1,
|
||||||
|
'top_p': 1,
|
||||||
|
'typical_p': 1,
|
||||||
|
'repetition_penalty': 1,
|
||||||
|
'encoder_repetition_penalty': 1,
|
||||||
|
'top_k': 50,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0,
|
||||||
|
'min_length': 0,
|
||||||
|
'length_penalty': 1,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'early_stopping': False,
|
||||||
|
}
|
||||||
|
params = json.loads(string)
|
||||||
|
for k in params[1]:
|
||||||
|
generate_params[k] = params[1][k]
|
||||||
|
for i in generate_reply(params[0], generate_params):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def create_apis():
|
||||||
|
t1 = gr.Textbox(visible=False)
|
||||||
|
t2 = gr.Textbox(visible=False)
|
||||||
|
dummy = gr.Button(visible=False)
|
||||||
|
|
||||||
|
input_params = [t1]
|
||||||
|
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
|
||||||
|
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
|
||||||
@@ -15,7 +15,7 @@ import gradio as gr
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
from modules import chat, shared, training, ui
|
from modules import chat, shared, training, ui, api
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt
|
||||||
@@ -538,6 +538,8 @@ def create_interface():
|
|||||||
else:
|
else:
|
||||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||||
|
|
||||||
|
api.create_apis()
|
||||||
|
|
||||||
# Authentication
|
# Authentication
|
||||||
auth = None
|
auth = None
|
||||||
if shared.args.gradio_auth_path is not None:
|
if shared.args.gradio_auth_path is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user