@@ -108,10 +108,7 @@ def load_model(model_name):
 
     # Default settings
     if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
-        if Path(f"torch-dumps/{model_name}.pt").exists():
-            print("Loading in .pt format...")
-            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
-        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
+        if model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
         else:
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda()
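
With the torch-dumps/ .pt fast path removed, every load that reaches this default-settings block now goes through AutoModelForCausalLM.from_pretrained: large GPT-Neo/OPT/Galactica checkpoints (13b/20b/30b in the name) take the 8-bit branch, and everything else is loaded in float16 (or bfloat16 with --bf16) and moved to the GPU. A minimal sketch of the branch decision only; the helper name wants_8bit is made up here for illustration and is not part of the patch:

# Sketch (not part of the patch): the size-based check that now picks
# the 8-bit branch vs. the fp16 .cuda() branch in the default-settings path.
def wants_8bit(model_name: str) -> bool:
    name = model_name.lower()
    return (name.startswith(('gpt-neo', 'opt-', 'galactica'))
            and any(size in name for size in ('13b', '20b', '30b')))

assert wants_8bit('OPT-13B')        # large OPT  -> load_in_8bit=True branch
assert not wants_8bit('opt-1.3b')   # small OPT  -> fp16 .cuda() branch
assert not wants_8bit('llama-30b')  # unmatched prefix -> fp16 .cuda() branch
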
@@ -425,7 +422,7 @@ def update_extensions_parameters(*kwargs):
             i += 1
 
 def get_available_models():
-    return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
+    return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith('.txt')], key=lambda x: x.lower())
 
 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
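
The get_available_models() change follows from the same cleanup: with torch-dumps/ no longer scanned there is nothing to deduplicate or to strip a .pt suffix from, so the function just lists models/ and skips .txt files. A small standalone sketch of the new behaviour, assuming only that a models/ directory exists next to the script:

# Sketch of the new listing logic: names under models/ except *.txt,
# sorted case-insensitively; torch-dumps/ is no longer consulted.
from pathlib import Path

def list_models(models_dir='models/'):
    return sorted((item.name for item in Path(models_dir).glob('*')
                   if not item.name.endswith('.txt')),
                  key=lambda x: x.lower())

print(list_models())  # e.g. ['gpt-j-6B', 'opt-13b'], depending on what is downloaded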