diff --git a/docs/examples/maxar_open_data.ipynb b/docs/examples/maxar_open_data.ipynb index 347e6bcc..2c64e89a 100644 --- a/docs/examples/maxar_open_data.ipynb +++ b/docs/examples/maxar_open_data.ipynb @@ -63,9 +63,7 @@ }, "outputs": [], "source": [ - "url = (\n", - " \"https://drive.google.com/file/d/1jIIC5hvSPeJEC0fbDhtxVWk2XV9AxsQD/view?usp=sharing\"\n", - ")" + "url = \"https://github.com/opengeos/datasets/releases/download/raster/Derna_sample.tif\"" ] }, { diff --git a/requirements.txt b/requirements.txt index 07cb902d..fefe4324 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ leafmap localtileserver matplotlib opencv-python +patool pycocotools pyproj rasterio diff --git a/samgeo/common.py b/samgeo/common.py index ed3fb40a..d990e590 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -209,19 +209,22 @@ def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False): model_types = { "vit_h": { "name": "sam_hq_vit_h.pth", - "url": "https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing", + "url": [ + "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.zip", + "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.z01", + ], }, "vit_l": { "name": "sam_hq_vit_l.pth", - "url": "https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing", + "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_l.pth", }, "vit_b": { "name": "sam_hq_vit_b.pth", - "url": "https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing", + "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_b.pth", }, "vit_tiny": { "name": "sam_hq_vit_tiny.pth", - "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth", + "url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_tiny.pth", }, } @@ -239,7 +242,10 @@ def 
download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False): if not os.path.exists(checkpoint): print(f"Model checkpoint for {model_type} not found.") url = model_types[model_type]["url"] - download_file(url, checkpoint) + if isinstance(url, str): + download_file(url, checkpoint) + elif isinstance(url, list): + download_files(url, checkpoint_dir, multi_part=True) return checkpoint
@@ -2987,3 +2993,160 @@ def merge_rasters(
         dstNodata=output_nodata,
         options=output_options,
     )
+
+
+def extract_archive(archive, outdir=None, **kwargs):
+    """Extract an archive (including a multi-part zip set) with patoolib.
+
+    patoolib is imported lazily; if the import fails, install_package("patool")
+    is called and the import is retried. ".zip" is appended to ``archive`` when
+    the path does not already end with it. Extraction is best-effort: any
+    exception raised by patoolib is printed and swallowed, never re-raised.
+
+    Args:
+        archive (str): The path to the archive file.
+        outdir (str, optional): The directory where the archive should be
+            extracted. Defaults to the directory containing ``archive``.
+        **kwargs: Arbitrary keyword arguments for the
+            patoolib.extract_archive function.
+
+    Returns:
+        bool: True if patoolib extracted the archive, False if extraction was
+            skipped because patoolib raised an exception.
+
+    Example:
+
+        extract_archive("models/sam_hq_vit_h.zip", outdir="models")
+    """
+    try:
+        import patoolib
+    except ImportError:
+        install_package("patool")
+        import patoolib
+
+    if not archive.endswith(".zip"):
+        archive = archive + ".zip"
+
+    if outdir is None:
+        # abspath keeps the default usable when ``archive`` is a bare file name,
+        # for which os.path.dirname() alone would return an empty string.
+        outdir = os.path.dirname(os.path.abspath(archive))
+
+    try:
+        patoolib.extract_archive(archive, outdir=outdir, **kwargs)
+    except Exception as e:
+        # Deliberately best-effort, but report the real error instead of only
+        # guessing at the cause.
+        print(f"Extraction of {archive} failed: {e}")
+        print("The unzipped files might already exist. Skipping extraction.")
+        return False
+    return True
+
+
+def download_files(
+    urls,
+    out_dir=None,
+    filenames=None,
+    quiet=False,
+    proxy=None,
+    speed=None,
+    use_cookies=True,
+    verify=True,
+    id=None,
+    fuzzy=False,
+    resume=False,
+    unzip=True,
+    overwrite=False,
+    subfolder=False,
+    multi_part=False,
+):
+    """Download files from URLs, including Google Drive shared URLs.
+
+    Each URL is handed to download_file in turn. With multi_part=True the
+    downloads are treated as the parts of one split zip archive: per-file
+    unzipping is turned off, the archive is extracted with extract_archive once
+    every part is downloaded, and the parts are deleted only if that extraction
+    succeeded.
+
+    Args:
+        urls (list): The list of urls to download. Google Drive URL is also supported.
+        out_dir (str, optional): The output directory. Defaults to None, which means the current working directory.
+        filenames (list, optional): Output filenames, one per URL. Default is the basename of each URL.
+        quiet (bool, optional): Suppress terminal output. Default is False.
+        proxy (str, optional): Proxy. Defaults to None.
+        speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
+        use_cookies (bool, optional): Flag to use cookies. Defaults to True.
+        verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Defaults to True.
+        id (str, optional): Google Drive's file ID. Defaults to None.
+        fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
+        resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
+        unzip (bool, optional): Unzip the file. Ignored (forced to False) when multi_part is True. Defaults to True.
+        overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
+        subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.
+        multi_part (bool, optional): If the file is a multi-part file. Defaults to False.
+
+    Raises:
+        ValueError: If filenames is given but its length differs from urls.
+
+    Examples:
+
+        files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"]
+        base_url = "https://github.com/opengeos/datasets/releases/download/models/"
+        urls = [base_url + f for f in files]
+        download_files(urls, out_dir="models", multi_part=True)
+    """
+
+    if out_dir is None:
+        out_dir = os.getcwd()
+
+    if filenames is None:
+        filenames = [None] * len(urls)
+    elif len(filenames) != len(urls):
+        # zip() below would otherwise silently drop the unmatched entries.
+        raise ValueError(
+            f"filenames has {len(filenames)} entries but urls has {len(urls)}."
+        )
+
+    if multi_part:
+        # Parts are extracted together after the loop, not one by one.
+        unzip = False
+
+    filepaths = []
+    for url, output in zip(urls, filenames):
+        name = os.path.basename(url) if output is None else output
+        filepath = os.path.join(out_dir, name)
+        filepaths.append(filepath)
+
+        download_file(
+            url,
+            filepath,
+            quiet,
+            proxy,
+            speed,
+            use_cookies,
+            verify,
+            id,
+            fuzzy,
+            resume,
+            unzip,
+            overwrite,
+            subfolder,
+        )
+
+    if not multi_part or not filepaths:
+        # Nothing to assemble; this also covers an empty ``urls`` list.
+        return
+
+    # The archive path is derived from the last part: same stem, ".zip" suffix.
+    last_part = filepaths[-1]
+    archive = os.path.splitext(last_part)[0] + ".zip"
+    if not extract_archive(archive, os.path.dirname(last_part)):
+        # Keep the downloaded parts so a failed extraction can be retried.
+        return
+
+    for part in filepaths:
+        try:
+            os.remove(part)
+        except FileNotFoundError:
+            # Already removed, e.g. when the same URL was listed twice.
+            pass
diff --git a/samgeo/fast_sam.py b/samgeo/fast_sam.py index 86609899..5dfc9b54 100644 --- a/samgeo/fast_sam.py +++ b/samgeo/fast_sam.py @@ -30,8 +30,8 @@ def __init__(self, model="FastSAM-x.pt", **kwargs): ) models = { - "FastSAM-x.pt": "https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing", - "FastSAM-s.pt": "https://drive.google.com/file/d/10XmSj6mmpmRb8NhXbtiuO9cTTBwR_9SV/view?usp=sharing", + "FastSAM-x.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-x.pt", + "FastSAM-s.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-s.pt", } if model not in models: