Merge branch 'main' into da3dsoul-main
This commit is contained in:
@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
|
||||
'num_beams': int(body.get('num_beams',1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
}
|
||||
|
||||
generator = generate_reply(
|
||||
question = prompt,
|
||||
max_new_tokens = int(body.get('max_length', 200)),
|
||||
do_sample=bool(body.get('do_sample', True)),
|
||||
temperature=float(body.get('temperature', 0.5)),
|
||||
top_p=float(body.get('top_p', 1)),
|
||||
typical_p=float(body.get('typical', 1)),
|
||||
repetition_penalty=float(body.get('rep_pen', 1.1)),
|
||||
encoder_repetition_penalty=1,
|
||||
top_k=int(body.get('top_k', 0)),
|
||||
min_length=int(body.get('min_length', 0)),
|
||||
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
|
||||
num_beams=int(body.get('num_beams',1)),
|
||||
penalty_alpha=float(body.get('penalty_alpha', 0)),
|
||||
length_penalty=float(body.get('length_penalty', 1)),
|
||||
early_stopping=bool(body.get('early_stopping', False)),
|
||||
seed=int(body.get('seed', -1)),
|
||||
prompt,
|
||||
generate_params,
|
||||
stopping_strings=body.get('stopping_strings', []),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules.chat import load_character
|
||||
from modules.html_generator import get_image_cache
|
||||
from modules.shared import gradio, settings
|
||||
from modules.shared import gradio
|
||||
|
||||
|
||||
def generate_css():
|
||||
@@ -64,22 +63,13 @@ def generate_html():
|
||||
for file in sorted(Path("characters").glob("*")):
|
||||
if file.suffix in [".json", ".yml", ".yaml"]:
|
||||
character = file.stem
|
||||
container_html = f'<div class="character-container">'
|
||||
container_html = '<div class="character-container">'
|
||||
image_html = "<div class='placeholder'></div>"
|
||||
|
||||
for i in [
|
||||
f"characters/{character}.png",
|
||||
f"characters/{character}.jpg",
|
||||
f"characters/{character}.jpeg",
|
||||
]:
|
||||
|
||||
path = Path(i)
|
||||
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
||||
if path.exists():
|
||||
try:
|
||||
image_html = f'<img src="file/{get_image_cache(path)}">'
|
||||
break
|
||||
except:
|
||||
continue
|
||||
image_html = f'<img src="file/{get_image_cache(path)}">'
|
||||
break
|
||||
|
||||
container_html += f'{image_html} <span class="character-name">{character}</span>'
|
||||
container_html += "</div>"
|
||||
|
||||
@@ -176,4 +176,4 @@ def ui():
|
||||
|
||||
force_btn.click(force_pic)
|
||||
generate_now_btn.click(force_pic)
|
||||
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
@@ -2,12 +2,11 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
|
||||
from modules import chat, shared
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation with
|
||||
# custom input text given by 'value' in the format [text, visible_text]
|
||||
input_hijack = {
|
||||
@@ -36,13 +35,11 @@ def generate_chat_picture(picture, name1, name2):
|
||||
def ui():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
|
||||
# Prepare the hijack with custom inputs
|
||||
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
|
||||
|
||||
# Call the generation function
|
||||
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||
|
||||
Reference in New Issue
Block a user