|
|
@@ -18,12 +18,16 @@ import re
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument('MODEL', type=str)
|
|
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
|
|
+parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
def get_file(args):
|
|
|
url = args[0]
|
|
|
output_folder = args[1]
|
|
|
+ idx = args[2]
|
|
|
+ tot = args[3]
|
|
|
|
|
|
+ print(f"Downloading file {idx} of {tot}...")
|
|
|
r = requests.get(url, stream=True)
|
|
|
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
|
|
|
total_size = int(r.headers.get('content-length', 0))
|
|
|
@@ -77,8 +81,8 @@ if __name__ == '__main__':
|
|
|
downloads.append(f'https://huggingface.co/{href}')
|
|
|
|
|
|
# Downloading the files
|
|
|
- print(f"Downloading the model to {output_folder}...")
|
|
|
- pool = multiprocessing.Pool(processes=4)
|
|
|
- results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))])
|
|
|
+ print(f"Downloading the model to {output_folder}")
|
|
|
+ pool = multiprocessing.Pool(processes=args.threads)
|
|
|
+ results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))])
|
|
|
pool.close()
|
|
|
pool.join()
|