script.py

import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from modules import shared
from modules.text_generation import encode, generate_reply

params = {
    'port': 5000,
}


class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        # GET /api/v1/model returns the name of the currently loaded model.
        if self.path == '/api/v1/model':
            self.send_response(200)
            self.end_headers()
            response = json.dumps({
                'result': shared.model_name
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/api/v1/generate':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            prompt = body['prompt']
            prompt_lines = [line.strip() for line in prompt.split('\n')]

            # Drop lines from the top of the prompt until it fits within the
            # requested context window.
            max_context = body.get('max_context_length', 2048)
            while len(prompt_lines) > 0 and len(encode('\n'.join(prompt_lines))) > max_context:
                prompt_lines.pop(0)

            prompt = '\n'.join(prompt_lines)

            # Map the KoboldAI-style request fields onto the webui's
            # generation parameters, with defaults for anything missing.
            generate_params = {
                'max_new_tokens': int(body.get('max_length', 200)),
                'do_sample': bool(body.get('do_sample', True)),
                'temperature': float(body.get('temperature', 0.5)),
                'top_p': float(body.get('top_p', 1)),
                'typical_p': float(body.get('typical', 1)),
                'repetition_penalty': float(body.get('rep_pen', 1.1)),
                'encoder_repetition_penalty': 1,
                'top_k': int(body.get('top_k', 0)),
                'min_length': int(body.get('min_length', 0)),
                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
                'num_beams': int(body.get('num_beams', 1)),
                'penalty_alpha': float(body.get('penalty_alpha', 0)),
                'length_penalty': float(body.get('length_penalty', 1)),
                'early_stopping': bool(body.get('early_stopping', False)),
                'seed': int(body.get('seed', -1)),
            }

            # generate_reply streams intermediate results; iterate to the end
            # and keep only the final answer.
            generator = generate_reply(
                prompt,
                generate_params,
                stopping_strings=body.get('stopping_strings', []),
            )

            answer = ''
            for a in generator:
                if isinstance(a, str):
                    answer = a
                else:
                    answer = a[0]

            # Return only the newly generated text, in KoboldAI's
            # {'results': [{'text': ...}]} response format.
            response = json.dumps({
                'results': [{
                    'text': answer[len(prompt):]
                }]
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)
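
# Illustrative exchange for the endpoint above (field values are made up):
#   POST /api/v1/generate
#   {"prompt": "Once upon a time", "max_length": 50, "temperature": 0.7}
# responds with a KoboldAI-style body such as:
#   {"results": [{"text": ", there was..."}]}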


def run_server():
    # Bind to all interfaces when --listen is passed, otherwise localhost only.
    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        # With --share, expose the server through a Cloudflare tunnel.
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting KoboldAI compatible api at {public_url}/api')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')

    server.serve_forever()


def setup():
    # Called by the webui when the extension loads; the daemon thread keeps
    # the API server from blocking shutdown.
    Thread(target=run_server, daemon=True).start()
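
# A minimal client sketch (not part of the extension): with the server
# running on the default port, the generate endpoint can be exercised from
# the standard library alone, e.g.:
#
#   import json
#   import urllib.request
#
#   req = urllib.request.Request(
#       'http://127.0.0.1:5000/api/v1/generate',
#       data=json.dumps({'prompt': 'Hello', 'max_length': 20}).encode('utf-8'),
#       headers={'Content-Type': 'application/json'},
#   )
#   with urllib.request.urlopen(req) as response:
#       print(json.loads(response.read())['results'][0]['text'])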