oobabooga 2 lat temu
rodzic
commit
3ac7c9c80a
2 zmienionych plików z 7 dodań i 32 usunięć
  1. 2 16
      api-example-stream.py
  2. 5 16
      api-example.py

+ 2 - 16
api-example-stream.py

@@ -36,6 +36,7 @@ async def run(context):
         'early_stopping': False,
         'seed': -1,
     }
+    payload = json.dumps([context, params])
     session = random_hash()
 
     async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +55,7 @@ async def run(context):
                         "session_hash": session,
                         "fn_index": 12,
                         "data": [
-                            context,
-                            params['max_new_tokens'],
-                            params['do_sample'],
-                            params['temperature'],
-                            params['top_p'],
-                            params['typical_p'],
-                            params['repetition_penalty'],
-                            params['encoder_repetition_penalty'],
-                            params['top_k'],
-                            params['min_length'],
-                            params['no_repeat_ngram_size'],
-                            params['num_beams'],
-                            params['penalty_alpha'],
-                            params['length_penalty'],
-                            params['early_stopping'],
-                            params['seed'],
+                            payload
                         ]
                     }))
                 case "process_starts":

+ 5 - 16
api-example.py

@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
 allowing you to use the API remotely.
 
 '''
+import json
+
 import requests
 
 # Server address
@@ -38,24 +40,11 @@ params = {
 # Input prompt
 prompt = "What I would like to say is the following: "
 
+payload = json.dumps([prompt, params])
+
 response = requests.post(f"http://{server}:7860/run/textgen", json={
     "data": [
-        prompt,
-        params['max_new_tokens'],
-        params['do_sample'],
-        params['temperature'],
-        params['top_p'],
-        params['typical_p'],
-        params['repetition_penalty'],
-        params['encoder_repetition_penalty'],
-        params['top_k'],
-        params['min_length'],
-        params['no_repeat_ngram_size'],
-        params['num_beams'],
-        params['penalty_alpha'],
-        params['length_penalty'],
-        params['early_stopping'],
-        params['seed'],
+        payload
     ]
 }).json()