-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
import csv | ||
import numpy as np | ||
from ultralytics import YOLO | ||
import json | ||
import torch | ||
from torch.utils.data import Dataset | ||
from pycocotools import mask as mask_util | ||
|
||
# RLE 인코딩 함수 | ||
def encode_mask_to_rle(mask): | ||
''' | ||
mask: numpy array binary mask | ||
1 - mask | ||
0 - background | ||
Returns encoded run length | ||
''' | ||
pixels = mask.flatten() | ||
pixels = np.concatenate([[0], pixels, [0]]) | ||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 | ||
runs[1::2] -= runs[::2] | ||
return ' '.join(str(x) for x in runs) | ||
|
||
# RLE 디코딩 함수 | ||
def decode_rle_to_mask(rle, height, width): | ||
s = rle.split() | ||
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] | ||
starts -= 1 | ||
ends = starts + lengths | ||
img = np.zeros(height * width, dtype=np.uint8) | ||
|
||
for lo, hi in zip(starts, ends): | ||
img[lo:hi] = 1 | ||
|
||
return img.reshape(height, width) | ||
|
||
# 모델 로드 | ||
model = YOLO("/data/ephemeral/home/jiwan/Gyeonggi-Autonomous-Driving-Center-Data-Utilization-Competition/yolo/runs/segment/train2/weights/best.pt") | ||
|
||
# 예측할 이미지 폴더 경로 설정 | ||
image_folder = "/data/ephemeral/home/dataset_yolo/test" | ||
|
||
# CSV 파일 생성 및 헤더 작성 | ||
csv_file_path = "/data/ephemeral/home/jiwan/yolo/result/predictions.csv" | ||
with open(csv_file_path, mode='w', newline='') as csv_file: | ||
writer = csv.writer(csv_file) | ||
writer.writerow(['image_name', 'class', 'rle']) | ||
|
||
# 이미지 폴더 내 모든 이미지 파일 리스트 가져오기 (이미지 이름 순서 유지) | ||
image_files = sorted([os.path.join(image_folder, file) for file in os.listdir(image_folder) if file.endswith(('.png', '.jpg', '.jpeg'))]) | ||
|
||
# 클래스 이름 정렬 (미리 정의된 클래스 이름 리스트 사용 후 정렬) | ||
class_names = [ | ||
"finger-1", "finger-2", "finger-3", "finger-4", "finger-5", | ||
"finger-6", "finger-7", "finger-8", "finger-9", "finger-10", | ||
"finger-11", "finger-12", "finger-13", "finger-14", "finger-15", | ||
"finger-16", "finger-17", "finger-18", "finger-19", | ||
"Trapezium", "Trapezoid", "Capitate", "Hamate", | ||
"Scaphoid", "Lunate", "Pisiform", "Triquetrum", | ||
"Radius", "Ulna" | ||
] | ||
|
||
# 각 이미지에 대해 예측 수행 후 저장 | ||
for image_path in image_files: | ||
# 예측 수행 | ||
results = model(image_path) | ||
|
||
# 각 결과에 대해 처리 | ||
for result in results: | ||
# CSV 파일에 결과 저장 | ||
with open(csv_file_path, mode='a', newline='') as csv_file: | ||
writer = csv.writer(csv_file) | ||
|
||
# 마스크 결과 가져오기 | ||
if result.masks is not None: | ||
# 박스와 마스크 데이터를 클래스 인덱스 순서대로 정렬 | ||
sorted_results = sorted(zip(result.boxes, result.masks.data), key=lambda x: int(x[0].cls)) | ||
|
||
for box, mask in sorted_results: | ||
# 클래스 가져오기 | ||
class_idx = int(box.cls) | ||
class_name = class_names[class_idx] | ||
|
||
# 마스크를 numpy 배열로 변환하고 RLE로 인코딩 | ||
mask_np = mask.cpu().numpy().astype(np.uint8) # dtype을 uint8로 변환 | ||
# 각 픽셀 값이 0 또는 1인지 확인 (이진화) | ||
mask_np = (mask_np > 0).astype(np.uint8) | ||
rle_str = encode_mask_to_rle(mask_np) # 커스텀 RLE 인코딩 사용 | ||
|
||
# CSV 파일에 기록 | ||
writer.writerow([os.path.basename(image_path), class_name, rle_str]) | ||
|
||
print(f"Predictions for {image_path} logged in {csv_file_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from ultralytics import YOLO | ||
import torch | ||
|
||
# 모델 설정 정보 | ||
yaml_path = "yolo11x-seg.yaml" # 새로운 모델 설정을 위한 YAML 파일 | ||
pretrained_weights = "yolo11x-seg.pt" # 사전 훈련된 가중치 파일 경로 | ||
transfer_weights = "yolo11x.pt" # YAML 빌드 시 사용할 가중치 파일 경로 | ||
|
||
# 모델 빌드 또는 가중치 로드 | ||
model = YOLO(yaml_path) # YAML 파일로 새 모델 빌드 | ||
model = YOLO(pretrained_weights) # 사전 훈련된 모델 로드 | ||
model = YOLO(yaml_path).load(transfer_weights) # YAML로 빌드하고 가중치 로드 | ||
|
||
# 모델 훈련 설정 | ||
data_path = "/data/ephemeral/home/jiwan/yolo/data.yaml" # 데이터 설정 파일 경로 | ||
epochs = 250 | ||
imgsz = 1280 | ||
batch_size = 4 # 추가한 batch_size 설정 | ||
|
||
# GPU 메모리 정리 | ||
torch.cuda.empty_cache() | ||
|
||
# 모델 훈련 시작 | ||
results = model.train( | ||
data=data_path, | ||
epochs=epochs, | ||
imgsz=imgsz, | ||
batch=batch_size | ||
) | ||
|
||
print("훈련이 완료되었습니다.") |