From 76b599d467bbfc74b735172469f6fc9a3d790088 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 8 Oct 2024 15:48:48 -0700 Subject: [PATCH] fix pascal loading error --- keras_cv/src/datasets/pascal_voc/segmentation.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/keras_cv/src/datasets/pascal_voc/segmentation.py b/keras_cv/src/datasets/pascal_voc/segmentation.py index 678acb6926..424faaabc4 100644 --- a/keras_cv/src/datasets/pascal_voc/segmentation.py +++ b/keras_cv/src/datasets/pascal_voc/segmentation.py @@ -40,7 +40,7 @@ class and instance segmentation masks. import os.path import random import tarfile -import xml +from xml.etree import ElementTree import numpy as np import tensorflow as tf @@ -152,18 +152,21 @@ def _download_data_file( if not local_dir_path: # download to ~/.keras/datasets/fname cache_dir = os.path.join(os.path.expanduser("~"), ".keras/datasets") - fname = os.path.join(cache_dir, os.path.basename(data_url)) + fname = os.path.join(os.path.basename(data_url)) else: # Make sure the directory exists if not os.path.exists(local_dir_path): os.makedirs(local_dir_path, exist_ok=True) # download to local_dir_path/fname - fname = os.path.join(local_dir_path, os.path.basename(data_url)) + fname = os.path.join(os.path.basename(data_url)) + cache_dir = local_dir_path data_directory = os.path.join(os.path.dirname(fname), extracted_dir) if not override_extract and os.path.exists(data_directory): logging.info("data directory %s already exist", data_directory) return data_directory - data_file_path = keras.utils.get_file(fname=fname, origin=data_url) + data_file_path = keras.utils.get_file( + fname=fname, origin=data_url, cache_dir=cache_dir + ) # Extra the data into the same directory as the tar file. data_directory = os.path.dirname(data_file_path) logging.info("Extract data into %s", data_directory) @@ -180,7 +183,7 @@ def _parse_annotation_data(annotation_file_path): """ with tf.io.gfile.GFile(annotation_file_path, "r") as f: - root = xml.etree.ElementTree.parse(f).getroot() + root = ElementTree.parse(f).getroot() size = root.find("size") width = int(size.find("width").text)