oobabooga пре 2 година
родитељ
комит
302e3b7973
1 измењених фајлова са 19 додато и 16 уклоњено
  1. 19 16
      extensions/api/script.py

+ 19 - 16
extensions/api/script.py

@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
                 prompt_lines.pop(0)
 
             prompt = '\n'.join(prompt_lines)
+            generate_params =  {
+                'max_new_tokens': int(body.get('max_length', 200)), 
+                'do_sample': bool(body.get('do_sample', True)),
+                'temperature': float(body.get('temperature', 0.5)), 
+                'top_p': float(body.get('top_p', 1)), 
+                'typical_p': float(body.get('typical', 1)), 
+                'repetition_penalty': float(body.get('rep_pen', 1.1)), 
+                'encoder_repetition_penalty': 1,
+                'top_k': int(body.get('top_k', 0)), 
+                'min_length': int(body.get('min_length', 0)),
+                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
+                'num_beams': int(body.get('num_beams',1)),
+                'penalty_alpha': float(body.get('penalty_alpha', 0)),
+                'length_penalty': float(body.get('length_penalty', 1)),
+                'early_stopping': bool(body.get('early_stopping', False)),
+                'seed': int(body.get('seed', -1)),
+            }
 
             generator = generate_reply(
-                question = prompt, 
-                max_new_tokens = int(body.get('max_length', 200)), 
-                do_sample=bool(body.get('do_sample', True)),
-                temperature=float(body.get('temperature', 0.5)), 
-                top_p=float(body.get('top_p', 1)), 
-                typical_p=float(body.get('typical', 1)), 
-                repetition_penalty=float(body.get('rep_pen', 1.1)), 
-                encoder_repetition_penalty=1, 
-                top_k=int(body.get('top_k', 0)), 
-                min_length=int(body.get('min_length', 0)),
-                no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
-                num_beams=int(body.get('num_beams',1)),
-                penalty_alpha=float(body.get('penalty_alpha', 0)),
-                length_penalty=float(body.get('length_penalty', 1)),
-                early_stopping=bool(body.get('early_stopping', False)),
-                seed=int(body.get('seed', -1)),
+                prompt, 
+                generate_params,
                 stopping_strings=body.get('stopping_strings', []),
             )