Skip to content

Commit

Permalink
Merge pull request #35 from thibaulttabarin/main
Browse files Browse the repository at this point in the history
Correct the alignment of the fish
  • Loading branch information
thibaulttabarin authored Aug 3, 2022
2 parents 97e0382 + fce5582 commit 45052ee
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 82 deletions.
19 changes: 11 additions & 8 deletions Scripts/Morphology_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,18 @@ def main(input_file, metadata_file, output_measure, output_landmark, output_pres
output_lm_image=None):

# Create the image segmentation object
img_seg = tc.segmented_image(input_file)
img_seg = tc.Segmented_image(input_file, align = True)
base_name = img_seg.base_name
# Create object measure_morphology
measure_morph = tc.Measure_morphology(input_file, align = True)
# Calcualte the mesaurements and landmarks
img_seg.get_all_measures_landmarks()

# Assign variables
measurements_bbox = img_seg.measurement_with_bbox
measurements_lm = img_seg.measurement_with_lm
measurements_area = img_seg.measurement_with_area
landmark = img_seg.landmark
presence_matrix = img_seg.presence_matrix
measurements_bbox = measure_morph.measurement_with_bbox
measurements_lm = measure_morph.measurement_with_lm
measurements_area = measure_morph.measurement_with_area
landmark = measure_morph.landmark
presence_matrix = measure_morph.presence_matrix

# Combine the 3 types of measurements (lm, bbox, area) and reorder the keys
measurement = {'base_name': base_name, **measurements_bbox, **measurements_lm, **measurements_area }
Expand Down Expand Up @@ -108,7 +109,9 @@ def main(input_file, metadata_file, output_measure, output_landmark, output_pres

if output_lm_image:

img_landmark = img_seg.visualize_landmark()

# create landmark visualization image and save it
img_landmark = measure_morph.visualize_landmark()
img_landmark.save(output_lm_image)


Expand Down
166 changes: 92 additions & 74 deletions Scripts/Traits_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,31 @@
from skimage.morphology import reconstruction


class segmented_image:
class Segmented_image:

def __init__(self, file_name):
def __init__(self, file_name, align = True, cutoff = 0.6):
self.file = file_name
self.image_name = os.path.split(file_name)[1]
self.base_name = self.image_name.rsplit('_',1)[0]
self.cutoff = cutoff
self.align = align

self.trait_color_dict={'background': [0, 0, 0],'dorsal_fin': [254, 0, 0],'adipos_fin': [0, 254, 0],
'caudal_fin': [0, 0, 254],'anal_fin': [254, 254, 0],'pelvic_fin': [0, 254, 254],
'pectoral_fin': [254, 0, 254],'head': [254, 254, 254],'eye': [0, 254, 102],
'caudal_fin_ray': [254, 102, 102],'alt_fin_ray': [254, 102, 204],
'trunk': [0, 124, 124]}
self.cutoff = 0.60

self.img_arr = self.import_image(file_name)
self.fish_angle = self.get_fish_angle_pca()
if align:
self.img_arr = self.align_fish()
self.old_fish_angle = self.fish_angle
self.fish_angle = self.get_fish_angle_pca()

self.get_channels_mask()
self.presence_matrix = self.get_presence_matrix()
self.fish_angle = self.get_fish_angle_pca()

def get_all_measures_landmarks(self):
'''
Execute the multiple functions that calculate landmarks and measurements
'''
self.landmark = self.all_landmark()
self.measurement_with_bbox = self.all_measure_using_bbox()
self.measurement_with_lm = self.all_measure_using_lm()
self.measurement_with_area = self.all_measure_area()



def import_image(self,file_name):
'''
Import the image from "image_path" and convert to np.array astype uint8 (0-255)
Expand All @@ -50,7 +48,43 @@ def import_image(self,file_name):
img_arr = np.array(img, dtype=np.uint8)

return img_arr


def get_fish_angle_pca(self):
'''
Calculate orientation (PCA) of the mask of whole fish
We choose to combine whole fish part and calculate orientation.
return value in degree
'''

# create a mask with all the fish traits
img_arr = self.img_arr
whole_fish = np.sum(img_arr,axis=2).astype(bool)

# Clean holes and remove isolated blobs and create a regionprop
trait_region = self.clean_trait_region(whole_fish)
angle_rad = trait_region.orientation
#fish_angle = (90-angle_rad*180/math.pi)
fish_angle = np.sign(angle_rad) * (90-abs(angle_rad*180/math.pi))

# + 0.0 remove negative sign on rounded 0.0 value
return round(fish_angle,2) + 0.0


def align_fish(self):
'''
Development
To align the fish horizontally
in order to get landmark 5 and 6
'''

img_arr = self.img_arr
angle_deg = self.fish_angle

image_align = Image.fromarray(img_arr).rotate(angle_deg)

return np.array(image_align, dtype=np.uint8)


def get_channels_mask(self):
'''
Convert the png image (numpy.ndarray, np.uint8) (320, 800, 3)
Expand All @@ -75,41 +109,6 @@ def get_channels_mask(self):

self.mask = mask

def get_fish_angle_pca(self):
'''
Calculate orientation (PCA) of the mask of whole fish
We choose to combine whole fish part and calculate orientation.
return value in degree
'''
fish_angle = "None"
mask = self.mask
whole_fish = np.zeros_like(None,dtype="uint8")
for i,(k,v) in enumerate(mask.items()):
whole_fish = whole_fish +v

# Check that the mask is not empty
if np.any(whole_fish):

trait_region = self.clean_trait_region(whole_fish)
angle_rad = trait_region.orientation
fish_angle = (90-angle_rad*180/math.pi)

return round(fish_angle,2)

def align_fish(self):
'''
Development
To align the fish horizontally
in order to get landmark 5 and 6
'''

img_arr = self.img_arr
angle_deg = self.fish_angle

image_align = Image.fromarray(img_arr).rotate(angle_deg)

return image_align

def remove_holes(self, image):

seed = np.copy(image)
Expand Down Expand Up @@ -171,27 +170,23 @@ def get_presence_matrix(self):
presence_matrix[trait_name] = temp_dict

return presence_matrix


class Measure_morphology(Segmented_image):

def __init__(self, file_name, align=True):

super().__init__(file_name, align=True)
self.get_all_measures_landmarks()

def get_one_property_all_trait(self, property_='centroid'):
def get_all_measures_landmarks(self):
'''
Create a dictionnary with key = trait and value the property selected by property_
example: {'dorsal_fin': (centroid[0], centroid[1]), 'trunk':(centroid[0], centroid[1])....}
Execute the multiple functions that calculate landmarks and measurements
'''

mask = self.mask
dict_property={}
# enumerate through the dictionary of mask to collect trait_name
for i, (trait_name, trait_mask) in enumerate(mask.items()):

trait_region = self.clean_trait_region(trait_mask)

if trait_region:
dict_property[trait_name] = trait_region[property_]
else:
dict_property[trait_name]=None

return dict_property
self.landmark = self.all_landmark()
self.measurement_with_bbox = self.all_measure_using_bbox()
self.measurement_with_lm = self.all_measure_using_lm()
self.measurement_with_area = self.all_measure_area()

def get_distance(self, a,b):
'''
Expand Down Expand Up @@ -565,11 +560,7 @@ def all_measure_area(self):
measure_area['HA_m'] = self.measure_head_area()

return measure_area






########################
# Measurement using bbox
########################
Expand Down Expand Up @@ -628,7 +619,7 @@ def all_measure_using_bbox(self):
'''
Collect the measurment for the fish for Meghan paper
'''
measures_bbox={'SL_bbox':'None', 'HL_bbox':'None', 'ED_bbox':'None','pOD_bbox':'None', 'fish_angle':'None' }
measures_bbox={'SL_bbox':'None', 'HL_bbox':'None', 'ED_bbox':'None','pOD_bbox':'None', 'FA_pca':'None' }

# SL standart length, length bbox of head+trunk

Expand All @@ -647,9 +638,36 @@ def all_measure_using_bbox(self):

return measures_bbox

def visualize_landmark(self):

landmark = self.all_landmark()
img_arr = self.img_arr
img = Image.fromarray(img_arr)
img1 = ImageDraw.Draw(img)

#
#fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 15)
fnt = ImageFont.load_default()
for i,(k,v) in enumerate(landmark.items()):

# landmark exist draw it on the image
if v:
row,col = v
xy = [(col-9,row-9),(col+9,row+9)]
img1.ellipse(xy, fill='gray', outline=None, width=1)

img1.text((col-6, row-6), k, font=fnt, fill='black')
# Display the image created
return img

class Visualization_morphology(Measure_morphology):
############################
# Visualization function
############################
def __init__(self, file_name, align=True):

super().__init__(file_name, align=True)

def visualize_trait(self, trait):

mask = self.mask
Expand Down

0 comments on commit 45052ee

Please sign in to comment.