Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add YOLOv4 to studio #120

Merged
merged 33 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1a7a5ff
Initial commit - add object detection route
bgoelTT Dec 9, 2024
b3fd6df
Add package-lock.json
bgoelTT Dec 9, 2024
a983c61
Add two-column object detection component
bgoelTT Dec 9, 2024
dfd0574
Add new layout and component structure
bgoelTT Dec 10, 2024
4bf1cfa
Use Aceternity UI file picker
bgoelTT Dec 10, 2024
7629cb6
adds tabs to control menu
anirudTT Dec 11, 2024
ebb1431
modifies to move webcam to main component
anirudTT Dec 11, 2024
f23d704
adds webcam component
anirudTT Dec 11, 2024
1058657
add the react package for webcam util
anirudTT Dec 11, 2024
8bba82e
add shadcnn tabs ui component
anirudTT Dec 11, 2024
21b0122
modifies file upload to show last uploaded file + color change + alwa…
anirudTT Dec 11, 2024
57a70f1
Fix containing element scroll and z-stack
bgoelTT Dec 11, 2024
0a9aabc
Merge branch 'object-detection' of github.com:tenstorrent/tt-studio i…
bgoelTT Dec 11, 2024
7cf7cb8
Add overflow scroll to main component
bgoelTT Dec 11, 2024
286015f
Allow images to assume full width of ObjectDetectionComponent
bgoelTT Dec 11, 2024
da1dd2b
Add YoloV4 model config to backend API
bgoelTT Dec 12, 2024
1ea1dd2
Create new object-detection endpoint & expand DeviceConfigurations en…
bgoelTT Dec 16, 2024
10db9f8
Add ModelType enumeration in frontend to faciliate conditional naviga…
bgoelTT Dec 16, 2024
a71e7c6
WIP add components to support:
anirudTT Dec 16, 2024
f7a619b
draw box on image
anirudTT Dec 16, 2024
076bf24
remove
anirudTT Dec 16, 2024
dddfc6c
Merge commit 'a71e7c69161c83a60e5b2cdc77e52422be690cb2' into object-d…
bgoelTT Dec 17, 2024
aadb55f
Merge commit '076bf24462f0166054e870e0f6cfc7f443d87c3b' into object-d…
bgoelTT Dec 17, 2024
ea9e912
Optimize real-time object detection to prevent frame backlog
anirudTT Dec 18, 2024
6a804a6
Ensure webcam stops completely when stop button is clicked +
anirudTT Dec 18, 2024
00edc81
ts fixes
anirudTT Dec 18, 2024
922cfa5
Fix aspect ratio of video container to 4:3
bgoelTT Dec 19, 2024
db3cc36
Fix navigation and add <img> to SourcePicker component - TODO - wire …
bgoelTT Dec 20, 2024
210c063
Refactor inference API call and UI
bgoelTT Dec 20, 2024
54b5259
Fix UI bugs
bgoelTT Dec 20, 2024
cd5c6f7
Merge branch 'staging' of github.com:tenstorrent/tt-studio into objec…
bgoelTT Dec 23, 2024
927aadc
Add API authentication to YOLOv4 backend
bgoelTT Jan 2, 2025
9ed1d10
Address PR comments
bgoelTT Jan 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/api/model_control/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
path("inference/", views.InferenceView.as_view()),
path("deployed/", views.DeployedModelsView.as_view()),
path("model_weights/", views.ModelWeightsView.as_view()),
path("object-detection/", views.ObjectDetectionInferenceView.as_view()),
path("health/", views.ModelHealthView.as_view()),
]
40 changes: 40 additions & 0 deletions app/api/model_control/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

# model_control/views.py
from pathlib import Path
import requests
from PIL import Image
import io

from rest_framework import status
from rest_framework.views import APIView
Expand All @@ -12,6 +15,7 @@

from .serializers import InferenceSerializer, ModelWeightsSerializer
from model_control.model_utils import (
encoded_jwt,
get_deploy_cache,
stream_response_from_external_api,
health_check,
Expand Down Expand Up @@ -97,3 +101,39 @@ def get(self, request, *args, **kwargs):
return Response(weights, status=status.HTTP_200_OK)
else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


class ObjectDetectionInferenceView(APIView):
    """Forward an uploaded image to a deployed object-detection model.

    The uploaded image is resized, re-encoded as JPEG and POSTed as a
    multipart ``file`` field to the internal URL of the deployment selected
    by ``deploy_id``. The JSON body returned by the model is relayed back to
    the caller unchanged.
    """

    # Resolution the image is resized to before it is forwarded.
    TARGET_SIZE = (320, 320)
    # Seconds to wait for the model container before giving up.
    REQUEST_TIMEOUT_S = 5

    def post(self, request, *args, **kwargs):
        """Run object detection on the uploaded ``image``.

        Expected request data:
            deploy_id: key into the deploy cache identifying the model.
            image: a single uploaded image file.

        Returns:
            200 with the model's JSON response on success.
            400 if the serializer rejects the request, the image is missing,
                or the image cannot be decoded.
            404 if ``deploy_id`` is not present in the deploy cache.
            401 if the model backend rejects the bearer token.
            500 for any other failure while talking to the model backend.
        """
        data = request.data
        logger.info(f"ObjectDetectionInferenceView data:={data}")
        serializer = InferenceSerializer(data=data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        deploy_id = data.get("deploy_id")
        upload = data.get("image")  # we should only receive 1 file
        image_file = getattr(upload, "file", None)
        if image_file is None:
            # Without this guard a missing/non-file field raises AttributeError.
            return Response(
                {"image": ["An image file is required."]},
                status=status.HTTP_400_BAD_REQUEST,
            )

        try:
            deploy = get_deploy_cache()[deploy_id]
        except KeyError:
            return Response(
                {"deploy_id": [f"Unknown deploy_id: {deploy_id}"]},
                status=status.HTTP_404_NOT_FOUND,
            )
        internal_url = "http://" + deploy["internal_url"]

        try:
            byte_im = self._encode_as_jpeg(image_file)
        except OSError:
            # Pillow signals unreadable/unsupported image data with OSError
            # (UnidentifiedImageError is a subclass).
            logger.exception("Could not decode uploaded image")
            return Response(
                {"image": ["Uploaded file is not a readable image."]},
                status=status.HTTP_400_BAD_REQUEST,
            )

        file = {"file": byte_im}
        headers = {"Authorization": f"Bearer {encoded_jwt}"}
        try:
            inference_data = requests.post(
                internal_url,
                files=file,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT_S,
            )
            inference_data.raise_for_status()
            payload = inference_data.json()
        except requests.exceptions.HTTPError as http_err:
            # HTTPError raised by raise_for_status() carries the response.
            if http_err.response.status_code == status.HTTP_401_UNAUTHORIZED:
                return Response(status=status.HTTP_401_UNAUTHORIZED)
            return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        except (requests.exceptions.RequestException, ValueError):
            # Connection errors, timeouts and non-JSON bodies previously
            # escaped as unhandled exceptions; report them explicitly.
            logger.exception("Object detection request to %s failed", internal_url)
            return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

        return Response(payload, status=status.HTTP_200_OK)

    def _encode_as_jpeg(self, image_file):
        """Return ``image_file`` resized to ``TARGET_SIZE`` as JPEG bytes."""
        with Image.open(image_file) as pil_image:
            # JPEG cannot store alpha or palette modes (e.g. RGBA/P PNGs),
            # so normalise to RGB before resizing and saving.
            resized = pil_image.convert("RGB").resize(self.TARGET_SIZE)
        buf = io.BytesIO()
        resized.save(buf, format="JPEG")
        return buf.getvalue()
3 changes: 3 additions & 0 deletions app/api/shared_config/device_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@


class DeviceConfigurations(Enum):
    """Device configurations a model implementation can be declared for.

    The ``*_WH_ARCH_YAML`` members signal that the
    ``wormhole_b0_80_arch_eth_dispatch.yaml`` file should be used.
    """

    CPU = auto()
    E150 = auto()
    N150 = auto()
    N150_WH_ARCH_YAML = auto()
    N300x4 = auto()
    N300x4_WH_ARCH_YAML = auto()


def detect_available_devices():
Expand Down
35 changes: 26 additions & 9 deletions app/api/shared_config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class ModelImpl:
model_id: str
image_name: str
image_tag: str
hf_model_path: str
device_configurations: Set["DeviceConfigurations"]
docker_config: Dict[str, Any]
user_uid: int # user inside docker container uid (for file permissions)
Expand All @@ -51,6 +50,7 @@ class ModelImpl:
service_route: str
env_file: str = ""
health_route: str = "/health"
hf_model_path: str = ""

def __post_init__(self):
self.docker_config.update({"volumes": self.get_volume_mounts()})
Expand All @@ -59,18 +59,22 @@ def __post_init__(self):
self.docker_config["environment"]["HF_HOME"] = Path(
backend_config.model_container_cache_root
).joinpath("huggingface")

# Set environment variable if N150 or N300x4 is in the device configurations
if DeviceConfigurations.N150 in self.device_configurations or DeviceConfigurations.N300x4 in self.device_configurations:
self.docker_config["environment"]["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

# Set environment variable if N150_WH_ARCH_YAML or N300x4_WH_ARCH_YAML is in the device configurations
milank94 marked this conversation as resolved.
Show resolved Hide resolved
if (
DeviceConfigurations.N150_WH_ARCH_YAML in self.device_configurations
or DeviceConfigurations.N300x4_WH_ARCH_YAML in self.device_configurations
):
self.docker_config["environment"]["WH_ARCH_YAML"] = (
"wormhole_b0_80_arch_eth_dispatch.yaml"
)

if self.env_file:
logger.info(f"Using env file: {self.env_file}")
# env file should be in persistent volume mounted
env_dict = load_dotenv_dict(self.env_file)
# env file overrides any existing docker environment variables
self.docker_config["environment"].update(env_dict)


@property
def image_version(self) -> str:
Expand Down Expand Up @@ -155,6 +159,19 @@ def base_docker_config():
# model_ids are unique strings to define a model, they could be uuids but
# using friendly strings prefixed with id_ is more helpful for debugging
model_implmentations_list = [
ModelImpl(
model_name="YOLOv4",
model_id="id_yolov4v0.0.1",
image_name="ghcr.io/tenstorrent/tt-inference-server/tt-metal-yolov4-src-base",
image_tag="v0.0.1-tt-metal-65d246482b3f",
device_configurations={DeviceConfigurations.N150},
docker_config=base_docker_config(),
user_uid=1000,
user_gid=1000,
shm_size="32G",
service_port=7000,
service_route="/objdetection_v2",
),
ModelImpl(
model_name="Mock-Llama-3.1-70B-Instruct",
model_id="id_mock_vllm_modelv0.0.1",
Expand All @@ -174,8 +191,8 @@ def base_docker_config():
model_id="id_tt-metal-falcon-7bv0.0.13",
image_name="tt-metal-falcon-7b",
image_tag="v0.0.13",
device_configurations={DeviceConfigurations.N150_WH_ARCH_YAML},
hf_model_path="tiiuae/falcon-7b-instruct",
device_configurations={DeviceConfigurations.N150},
docker_config=base_docker_config(),
user_uid=1000,
user_gid=1000,
Expand All @@ -189,7 +206,7 @@ def base_docker_config():
image_name="ghcr.io/tenstorrent/tt-inference-server/tt-metal-llama3-70b-src-base-vllm",
image_tag="v0.0.3-tt-metal-385904186f81-384f1790c3be",
hf_model_path="meta-llama/Llama-3.1-70B-Instruct",
device_configurations={DeviceConfigurations.N300x4},
device_configurations={DeviceConfigurations.N300x4_WH_ARCH_YAML},
docker_config=base_docker_config(),
user_uid=1000,
user_gid=1000,
Expand All @@ -204,7 +221,7 @@ def base_docker_config():
image_name="ghcr.io/tenstorrent/tt-inference-server/tt-metal-mistral-7b-src-base",
image_tag="v0.0.3-tt-metal-v0.52.0-rc33",
hf_model_path="mistralai/Mistral-7B-Instruct-v0.2",
device_configurations={DeviceConfigurations.N300x4},
device_configurations={DeviceConfigurations.N300x4_WH_ARCH_YAML},
docker_config=base_docker_config(),
user_uid=1000,
user_gid=1000,
Expand Down
6 changes: 3 additions & 3 deletions app/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ services:
- "8000:8000"
# command: bash
# dev server can be used for breakpoint debugging, does not support streaming
# command: ./manage.py runserver 0.0.0.0:8000
# command: python ./manage.py runserver 0.0.0.0:8000
# gunicorn is used for production, supports streaming

command: gunicorn --workers 3 --bind 0.0.0.0:8000 --preload --timeout 1200 api.wsgi:application
Expand Down Expand Up @@ -52,7 +52,7 @@ services:
# On first application load resources for transformers/etc
# are downloaded. The UI should not start until these resources
# have been downloaded. Adjust timeout if on a very slow connection
test: ["CMD", "curl", "-f", "http://localhost:8000/up/"]
test: [ "CMD", "curl", "-f", "http://localhost:8000/up/" ]
timeout: 120s
interval: 10s
retries: 5
Expand Down Expand Up @@ -93,7 +93,7 @@ services:
- "8111:8111"
healthcheck:
# Adjust below to match your container port
test: ["CMD", "curl", "-f", "http://localhost:8111/api/v1/heartbeat"]
test: [ "CMD", "curl", "-f", "http://localhost:8111/api/v1/heartbeat" ]
interval: 10s
timeout: 10s
retries: 3
Expand Down
Loading
Loading