fix pascal loading error #2496

Open · wants to merge 1 commit into base: master

13 changes: 8 additions & 5 deletions keras_cv/src/datasets/pascal_voc/segmentation.py
@@ -40,7 +40,7 @@ class and instance segmentation masks.
 import os.path
 import random
 import tarfile
-import xml
+from xml.etree import ElementTree

 import numpy as np
 import tensorflow as tf
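A note on why the import change above matters: a bare `import xml` does not load the `xml.etree` submodule, so `xml.etree.ElementTree` is only usable if some other import has already pulled the submodule in, which makes the failure intermittent. A minimal sketch of the error and the fix, assuming a process where nothing else has imported `xml.etree` yet:

```python
# Sketch of the failure mode the import change addresses. Assumes no other
# module in this process has imported xml.etree yet; if one has, the
# attribute happens to exist and the bug stays hidden.
import xml

try:
    xml.etree.ElementTree  # `import xml` does not bind the etree submodule
except AttributeError as err:
    print(err)  # e.g. "module 'xml' has no attribute 'etree'"

# Importing the submodule explicitly always binds it:
from xml.etree import ElementTree

root = ElementTree.fromstring("<size><width>500</width></size>")
print(root.find("width").text)  # 500
```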
@@ -152,18 +152,21 @@ def _download_data_file(
     if not local_dir_path:
         # download to ~/.keras/datasets/fname
         cache_dir = os.path.join(os.path.expanduser("~"), ".keras/datasets")
-        fname = os.path.join(cache_dir, os.path.basename(data_url))
+        fname = os.path.join(os.path.basename(data_url))
     else:
         # Make sure the directory exists
         if not os.path.exists(local_dir_path):
             os.makedirs(local_dir_path, exist_ok=True)
         # download to local_dir_path/fname
-        fname = os.path.join(local_dir_path, os.path.basename(data_url))
+        fname = os.path.join(os.path.basename(data_url))
+        cache_dir = local_dir_path
     data_directory = os.path.join(os.path.dirname(fname), extracted_dir)
     if not override_extract and os.path.exists(data_directory):
         logging.info("data directory %s already exists", data_directory)
         return data_directory
-    data_file_path = keras.utils.get_file(fname=fname, origin=data_url)
+    data_file_path = keras.utils.get_file(
+        fname=fname, origin=data_url, cache_dir=cache_dir
+    )
     # Extract the data into the same directory as the tar file.
     data_directory = os.path.dirname(data_file_path)
     logging.info("Extract data into %s", data_directory)
@@ -180,7 +183,7 @@ def _parse_annotation_data(annotation_file_path):

"""
with tf.io.gfile.GFile(annotation_file_path, "r") as f:
root = xml.etree.ElementTree.parse(f).getroot()
root = ElementTree.parse(f).getroot()

size = root.find("size")
width = int(size.find("width").text)
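For reference, the same `ElementTree` parsing pattern as `_parse_annotation_data`, run on an in-memory fragment instead of a file handle; the XML snippet is illustrative, not a real PASCAL VOC annotation:

```python
# Standalone sketch of the parsing done above; the annotation string is a
# made-up fragment with the same <size> structure as a VOC annotation.
from xml.etree import ElementTree

annotation = """
<annotation>
  <size><width>500</width><height>375</height><depth>3</depth></size>
</annotation>
"""

root = ElementTree.fromstring(annotation)
size = root.find("size")
width = int(size.find("width").text)
height = int(size.find("height").text)
print(width, height)  # 500 375
```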