Skip to content

Commit

Permalink
PathTraversal fix (#1159)
Browse files Browse the repository at this point in the history
* SAST PathTraversal fix

Signed-off-by: Akihiko Kuroda <[email protected]>
  • Loading branch information
akihikokuroda authored Jan 19, 2024
1 parent dea441a commit 7ce8151
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 11 deletions.
5 changes: 4 additions & 1 deletion gateway/api/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
decrypt_env_vars,
generate_cluster_name,
)
from utils import sanitize_file_path
from main import settings

logger = logging.getLogger("commands")
Expand Down Expand Up @@ -79,7 +80,9 @@ def submit(self, job: Job) -> Optional[str]:
_, dependencies = try_json_loads(program.dependencies)
with tarfile.open(program.artifact.path) as file:
extract_folder = os.path.join(
settings.MEDIA_ROOT, "tmp", str(uuid.uuid4())
sanitize_file_path(str(settings.MEDIA_ROOT)),
"tmp",
str(uuid.uuid4()),
)
file.extractall(extract_folder)

Expand Down
33 changes: 26 additions & 7 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from rest_framework.decorators import action
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
from utils import sanitize_file_path

from .exceptions import InternalServerErrorException, ResourceNotFoundException
from .models import Program, Job
Expand Down Expand Up @@ -361,7 +362,10 @@ def list(self, request):
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.list", context=ctx):
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
)
if os.path.exists(user_dir):
files = [
os.path.basename(path)
Expand Down Expand Up @@ -390,8 +394,13 @@ def download(self, request): # pylint: disable=invalid-name
if requested_file_name is not None:
# look for file in user's folder
filename = os.path.basename(requested_file_name)
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
file_path = os.path.join(user_dir, filename)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)

if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
chunk_size = 8192
Expand Down Expand Up @@ -423,8 +432,13 @@ def delete(self, request): # pylint: disable=invalid-name
if request.data and "file" in request.data:
# look for file in user's folder
filename = os.path.basename(request.data["file"])
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
file_path = os.path.join(user_dir, filename)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)

if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
os.remove(file_path)
Expand All @@ -442,8 +456,13 @@ def upload(self, request): # pylint: disable=invalid-name
with tracer.start_as_current_span("gateway.files.download", context=ctx):
upload_file = request.FILES["file"]
filename = os.path.basename(upload_file.name)
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
file_path = os.path.join(user_dir, filename)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(request.user.username),
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)
)
with open(file_path, "wb+") as destination:
for chunk in upload_file.chunks():
destination.write(chunk)
Expand Down
8 changes: 5 additions & 3 deletions gateway/main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import sys
from datetime import timedelta
from pathlib import Path
from utils import sanitize_file_path


# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
Expand Down Expand Up @@ -198,10 +200,10 @@
# https://docs.djangoproject.com/en/4.1/howto/static-files/

STATIC_URL = "static/"
STATIC_ROOT = os.path.join(BASE_DIR, "static")
STATIC_ROOT = os.path.join(sanitize_file_path(str(BASE_DIR)), "static")

MEDIA_URL = "media/"
MEDIA_ROOT = os.path.join(BASE_DIR, "media")
MEDIA_ROOT = os.path.join(sanitize_file_path(str(BASE_DIR)), "media")

# Default primary key field type
# https://docs.djangoproject.com/en/4.1/ref/settings/#default-auto-field
Expand Down Expand Up @@ -258,7 +260,7 @@
"REFRESH_TOKEN_LIFETIME": timedelta(days=20),
}

MEDIA_ROOT = os.path.join(BASE_DIR, "media")
MEDIA_ROOT = os.path.join(sanitize_file_path(str(BASE_DIR)), "media")
MEDIA_URL = "/media/"

# custom token auth
Expand Down
6 changes: 6 additions & 0 deletions gateway/tests/api/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_files_list(self):
"resources",
"fake_media",
)
media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_file_download(self):
"resources",
"fake_media",
)
media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
Expand All @@ -87,6 +89,8 @@ def test_file_delete(self):
"resources",
"fake_media",
)
media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))

with open(
os.path.join(media_root, "test_user", "artifact_delete.tar"), "w"
) as fp:
Expand Down Expand Up @@ -115,6 +119,7 @@ def test_non_existing_file_delete(self):
"resources",
"fake_media",
)
media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
Expand All @@ -137,6 +142,7 @@ def test_file_upload(self):
"resources",
"fake_media",
)
media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
Expand Down
49 changes: 49 additions & 0 deletions gateway/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
===========================================================
Utilities (:mod:`quantum_serverless.utils.utils`)
===========================================================
.. currentmodule:: quantum_serverless.utils.utils
Quantum serverless utilities
====================================
.. autosummary::
:toctree: ../stubs/
utility functions
"""

import os
import re


def sanitize_file_path(path: str) -> str:
    """Sanitize a file path so it cannot be used for path traversal.

    Sanitization:
        Every occurrence of the character string '..' is replaced with '_'.
        Every run of characters other than '0-9a-zA-Z-_.' and the
        platform's directory delimiter (``os.sep``) is replaced with a
        single '_'.

    Args:
        path: file path

    Returns:
        sanitized filepath
    """
    # str.replace is a no-op when '..' is absent, so no membership test
    # is needed beforehand.
    path = path.replace("..", "_")
    # os.sep must be escaped: on Windows it is a backslash, which would
    # otherwise escape the closing bracket and make the pattern invalid.
    pattern = "[^0-9a-zA-Z-_." + re.escape(os.sep) + "]+"
    return re.sub(pattern, "_", path)

0 comments on commit 7ce8151

Please sign in to comment.