|
|
@@ -16,23 +16,17 @@ from pathlib import Path
|
|
|
|
|
|
import requests
|
|
|
import tqdm
|
|
|
+from tqdm.contrib.concurrent import thread_map
|
|
|
|
|
|
-def get_file(args):
|
|
|
- url = args[0]
|
|
|
- output_folder = args[1]
|
|
|
- idx = args[2]
|
|
|
- tot = args[3]
|
|
|
-
|
|
|
- print(f"Downloading file {idx} of {tot}...")
|
|
|
+def get_file(url, output_folder):
|
|
|
r = requests.get(url, stream=True)
|
|
|
- with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
|
|
+ with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
|
|
|
total_size = int(r.headers.get('content-length', 0))
|
|
|
block_size = 1024
|
|
|
- t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
|
|
|
- for data in r.iter_content(block_size):
|
|
|
- t.update(len(data))
|
|
|
- f.write(data)
|
|
|
- t.close()
|
|
|
+ with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
|
|
+ for data in r.iter_content(block_size):
|
|
|
+ t.update(len(data))
|
|
|
+ f.write(data)
|
|
|
|
|
|
def sanitize_branch_name(branch_name):
|
|
|
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
|
@@ -143,13 +137,8 @@ def get_download_links_from_huggingface(model, branch):
|
|
|
|
|
|
return links, is_lora
|
|
|
|
|
|
-def download_files(file_list, output_folder, num_processes=8):
|
|
|
- with multiprocessing.Pool(processes=num_processes) as pool:
|
|
|
- args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)]
|
|
|
- for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)):
|
|
|
- pass
|
|
|
- pool.close()
|
|
|
- pool.join()
|
|
|
+def download_files(file_list, output_folder, num_threads=8):
|
|
|
+ thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
parser = argparse.ArgumentParser()
|
|
|
@@ -187,4 +176,4 @@ if __name__ == '__main__':
|
|
|
|
|
|
# Downloading the files
|
|
|
print(f"Downloading the model to {output_folder}")
|
|
|
- download_files(links, output_folder, num_processes=args.threads)
|
|
|
+ download_files(links, output_folder, args.threads)
|