| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- import json
- from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
- from threading import Thread
- from modules import shared
- from modules.text_generation import encode, generate_reply
# Extension settings: TCP port the KoboldAI-compatible API listens on.
params = dict(port=5000)
class Handler(BaseHTTPRequestHandler):
    """Request handler for a minimal KoboldAI-compatible HTTP API.

    Routes:
      GET  /api/v1/model     -> {"result": <shared.model_name>}
      POST /api/v1/generate  -> {"results": [{"text": <generated continuation>}]}

    Any other path is answered with 404; a malformed POST body with 400.
    """

    def _send_json(self, payload, status=200):
        """Send *payload* serialized as UTF-8 JSON with the given HTTP status."""
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        # An explicit length lets clients read the body without relying on
        # the connection being closed.
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_json_body(self):
        """Return the parsed JSON request body, or None if absent or invalid."""
        try:
            length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            return None
        if length <= 0:
            return None
        try:
            # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
            return json.loads(self.rfile.read(length).decode('utf-8'))
        except ValueError:
            return None

    @staticmethod
    def _truncate_prompt(prompt, max_context, tokenize):
        """Strip every prompt line, then drop the oldest lines until it fits.

        Args:
            prompt: Raw prompt text; lines are separated by '\\n'.
            max_context: Maximum number of tokens the returned prompt may have.
            tokenize: Callable mapping a string to a sized token sequence
                (the server passes ``encode``).

        Returns:
            The stripped lines re-joined with '\\n', with leading lines
            removed while the token count exceeds ``max_context``. May be ''.
        """
        lines = [line.strip() for line in prompt.split('\n')]
        # Stop once the list is empty: a tokenizer that emits tokens even for
        # '' (e.g. a BOS token) would otherwise make pop() fail on [].
        while lines and len(tokenize('\n'.join(lines))) > max_context:
            lines.pop(0)
        return '\n'.join(lines)

    @staticmethod
    def _build_generate_params(body):
        """Map a KoboldAI-style request body to ``generate_reply`` parameters.

        Several KoboldAI field names differ from the internal ones
        (max_length -> max_new_tokens, typical -> typical_p,
        rep_pen -> repetition_penalty). Missing fields take the defaults
        below.

        Raises:
            TypeError, ValueError: if a field cannot be coerced to its type.
        """
        return {
            'max_new_tokens': int(body.get('max_length', 200)),
            'do_sample': bool(body.get('do_sample', True)),
            'temperature': float(body.get('temperature', 0.5)),
            'top_p': float(body.get('top_p', 1)),
            'typical_p': float(body.get('typical', 1)),
            'repetition_penalty': float(body.get('rep_pen', 1.1)),
            'encoder_repetition_penalty': 1,
            'top_k': int(body.get('top_k', 0)),
            'min_length': int(body.get('min_length', 0)),
            'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
            'num_beams': int(body.get('num_beams', 1)),
            'penalty_alpha': float(body.get('penalty_alpha', 0)),
            'length_penalty': float(body.get('length_penalty', 1)),
            'early_stopping': bool(body.get('early_stopping', False)),
            'seed': int(body.get('seed', -1)),
            'add_bos_token': bool(body.get('add_bos_token', True)),
            'custom_stopping_strings': body.get('custom_stopping_strings', []),
            'truncation_length': int(body.get('truncation_length', 2048)),
            'ban_eos_token': bool(body.get('ban_eos_token', False)),
        }

    def do_GET(self):
        """Report the name of the currently loaded model."""
        if self.path == '/api/v1/model':
            self._send_json({'result': shared.model_name})
        else:
            self.send_error(404)

    def do_POST(self):
        """Generate a continuation of the request's ``prompt``.

        Responds 404 for unknown paths. Responds 400 when the body is not a
        JSON object with a string ``prompt``, or when a numeric parameter is
        not convertible. The response is only sent after generation
        finishes, so a failure cannot leave a half-written 200.
        """
        if self.path != '/api/v1/generate':
            self.send_error(404)
            return

        body = self._read_json_body()
        if not isinstance(body, dict) or not isinstance(body.get('prompt'), str):
            self.send_error(400, 'Body must be a JSON object with a string prompt')
            return

        try:
            max_context = int(body.get('max_context_length', 2048))
            generate_params = self._build_generate_params(body)
        except (TypeError, ValueError):
            self.send_error(400, 'Invalid generation parameter')
            return

        prompt = self._truncate_prompt(body['prompt'], max_context, encode)

        answer = ''
        for output in generate_reply(prompt, generate_params):
            # generate_reply yields either plain strings or sequences whose
            # first item is the text; only the final yield is kept.
            answer = output if isinstance(output, str) else output[0]

        # The slice assumes the yielded text starts with the prompt, as the
        # original implementation did. NOTE(review): confirm for all modes.
        self._send_json({'results': [{'text': answer[len(prompt):]}]})
def run_server():
    """Serve the KoboldAI-compatible API until the process exits.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1, on ``params['port']``. When ``shared.args.share`` is set, it
    also tries to open a public tunnel via ``flask_cloudflared``. If that
    package is missing, a warning is printed and the local address is
    announced instead, because the API is still served locally.
    """
    host = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    port = params['port']
    local_url = f'http://{host}:{port}/api'

    # Using the server as a context manager closes its socket if
    # serve_forever() ever returns or raises.
    with ThreadingHTTPServer((host, port), Handler) as server:
        public_url = None
        if shared.args.share:
            try:
                from flask_cloudflared import _run_cloudflared
            except ImportError:
                print('You should install flask_cloudflared manually')
            else:
                # Second argument is presumably cloudflared's metrics port
                # (private API of flask_cloudflared) - verify on upgrade.
                public_url = _run_cloudflared(port, port + 1)

        if public_url is not None:
            print(f'Starting KoboldAI compatible api at {public_url}/api')
        else:
            print(f'Starting KoboldAI compatible api at {local_url}')
        server.serve_forever()
def setup():
    """Start the API server on a background daemon thread and return at once.

    Because the thread is a daemon, it does not keep the process alive on
    its own.
    """
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()
|