# script.py — KoboldAI-compatible HTTP API extension (3.0 KB)
  1. from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
  2. from threading import Thread
  3. from modules import shared
  4. from modules.text_generation import generate_reply, encode
  5. import json
# Extension settings; 'port' is the TCP port the API server binds to
# (the cloudflared tunnel, if used, takes port + 1 — see run_server).
params = {
    'port': 5000,
}

# Module-level handle to the running ThreadingHTTPServer; stays None until
# run_server() is started (ui() uses this to avoid launching twice).
server = None
  10. class Handler(BaseHTTPRequestHandler):
  11. def do_GET(self):
  12. if self.path == '/api/v1/model':
  13. self.send_response(200)
  14. self.end_headers()
  15. response = json.dumps({
  16. 'result': shared.model_name
  17. })
  18. self.wfile.write(response.encode('utf-8'))
  19. else:
  20. self.send_error(404)
  21. def do_POST(self):
  22. content_length = int(self.headers['Content-Length'])
  23. body = json.loads(self.rfile.read(content_length).decode('utf-8'))
  24. if self.path == '/api/v1/generate':
  25. self.send_response(200)
  26. self.send_header('Content-Type', 'application/json')
  27. self.end_headers()
  28. prompt = body['prompt']
  29. prompt_lines = [l.strip() for l in prompt.split('\n')]
  30. max_context = body.get('max_context_length', 2048)
  31. while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
  32. prompt_lines.pop(0)
  33. prompt = '\n'.join(prompt_lines)
  34. generator = generate_reply(
  35. question = prompt,
  36. max_new_tokens = body.get('max_length', 200),
  37. do_sample=True,
  38. temperature=body.get('temperature', 0.5),
  39. top_p=body.get('top_p', 1),
  40. typical_p=body.get('typical', 1),
  41. repetition_penalty=body.get('rep_pen', 1.1),
  42. encoder_repetition_penalty=1,
  43. top_k=body.get('top_k', 0),
  44. min_length=0,
  45. no_repeat_ngram_size=0,
  46. num_beams=1,
  47. penalty_alpha=0,
  48. length_penalty=1,
  49. early_stopping=False,
  50. )
  51. answer = ''
  52. for a in generator:
  53. answer = a[0]
  54. response = json.dumps({
  55. 'results': [{
  56. 'text': answer[len(prompt):]
  57. }]
  58. })
  59. self.wfile.write(response.encode('utf-8'))
  60. else:
  61. self.send_error(404)
  62. def run_server():
  63. global server
  64. server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
  65. server = ThreadingHTTPServer(server_addr, Handler)
  66. if shared.args.share:
  67. try:
  68. from flask_cloudflared import _run_cloudflared
  69. public_url = _run_cloudflared(params['port'], params['port'] + 1)
  70. print(f'Starting KoboldAI compatible api at {public_url}/api')
  71. except ImportError:
  72. print('You should install flask_cloudflared manually')
  73. else:
  74. print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
  75. server.serve_forever()
  76. def ui():
  77. if server is None:
  78. Thread(target=run_server, daemon=True).start()