浏览代码

Give API extension access to all generate_reply parameters (#744)

* Make every parameter of the generate_reply function parameterizable

* Add stopping strings as parameterizable
Niels Mündler 2 年之前
父节点
当前提交
7aab88bcc6
共有 1 个文件被更改,包括 9 次插入8 次删除
  1. 9 8
      extensions/api/script.py

+ 9 - 8
extensions/api/script.py

@@ -44,20 +44,21 @@ class Handler(BaseHTTPRequestHandler):
             generator = generate_reply(
                 question = prompt, 
                 max_new_tokens = int(body.get('max_length', 200)), 
-                do_sample=True, 
+                do_sample=bool(body.get('do_sample', True)),
                 temperature=float(body.get('temperature', 0.5)), 
                 top_p=float(body.get('top_p', 1)), 
                 typical_p=float(body.get('typical', 1)), 
                 repetition_penalty=float(body.get('rep_pen', 1.1)), 
                 encoder_repetition_penalty=1, 
                 top_k=int(body.get('top_k', 0)), 
-                min_length=0, 
-                no_repeat_ngram_size=0, 
-                num_beams=1, 
-                penalty_alpha=0, 
-                length_penalty=1,
-                early_stopping=False,
-                seed=-1,
+                min_length=int(body.get('min_length', 0)),
+                no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
+                num_beams=int(body.get('num_beams',1)),
+                penalty_alpha=float(body.get('penalty_alpha', 0)),
+                length_penalty=float(body.get('length_penalty', 1)),
+                early_stopping=bool(body.get('early_stopping', False)),
+                seed=int(body.get('seed', -1)),
+                stopping_strings=body.get('stopping_strings', []),
             )
 
             answer = ''