api.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import json
  2. import gradio as gr
  3. from modules import shared
  4. from modules.text_generation import generate_reply
  5. def generate_reply_wrapper(string):
  6. generate_params = {
  7. 'do_sample': True,
  8. 'temperature': 1,
  9. 'top_p': 1,
  10. 'typical_p': 1,
  11. 'repetition_penalty': 1,
  12. 'encoder_repetition_penalty': 1,
  13. 'top_k': 50,
  14. 'num_beams': 1,
  15. 'penalty_alpha': 0,
  16. 'min_length': 0,
  17. 'length_penalty': 1,
  18. 'no_repeat_ngram_size': 0,
  19. 'early_stopping': False,
  20. }
  21. params = json.loads(string)
  22. for k in params[1]:
  23. generate_params[k] = params[1][k]
  24. for i in generate_reply(params[0], generate_params):
  25. yield i
  26. def create_apis():
  27. t1 = gr.Textbox(visible=False)
  28. t2 = gr.Textbox(visible=False)
  29. dummy = gr.Button(visible=False)
  30. input_params = [t1]
  31. output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
  32. dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')