Merge pull request #92 from kleok/dev
Optimize prediction funcs, add metadata to geojson files, plotting funcs
kleok authored Oct 26, 2024
2 parents e4dd2db + 37fc1cf commit f7c51af
Showing 6 changed files with 179 additions and 103 deletions.
24 changes: 12 additions & 12 deletions Floodpyapp_Vit.ipynb

Large diffs are not rendered by default.

146 changes: 85 additions & 61 deletions floodpy/Floodwater_delineation/Vit_approach/Predict_flooded_regions.py
@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 from torchvision import transforms
-import xbatcher
+import torch.nn.functional as F
 from tqdm import tqdm
 import sys
 
@@ -36,88 +36,112 @@ def predict_flooded_regions(Floodpy_app, ViT_model_filename, device):
     vit_model = torch.load(ViT_model_filename)
     vit_model.to(device)
 
-    batch_size = 224
-    starting_points = np.arange(0, batch_size, batch_size/4).astype(np.int32)
+    patch_size = 224
+    batch_size = 1
+    bands_out = 1
+    starting_points = np.arange(0, patch_size, patch_size/4).astype(np.int32)
     data_mean = [0.0953, 0.0264]
     data_std = [0.0427, 0.0215]
     clamp_input = 0.15
     Normalize = transforms.Normalize(mean=data_mean, std=data_std)
 
+    # load Sentinel-1 data
     S1_dataset = xr.open_dataset(Floodpy_app.S1_stack_filename, decode_coords='all')
     pre1_time, pre2_time, post_time = S1_dataset.time[-3:] # assumes the last time is the recent (flooded) one
 
-    prediction_data_list = []
+    post = np.power(10, (np.stack([S1_dataset.isel(time=-1).VV_dB.values, S1_dataset.isel(time=-1).VH_dB.values], axis=0))/10)
+    pre1 = np.power(10, (np.stack([S1_dataset.isel(time=-2).VV_dB.values, S1_dataset.isel(time=-2).VH_dB.values], axis=0))/10)
+    pre2 = np.power(10, (np.stack([S1_dataset.isel(time=-3).VV_dB.values, S1_dataset.isel(time=-3).VH_dB.values], axis=0))/10)
 
-    for starting_point in starting_points:
-        print('Predictions with starting point: {} pixel'.format(starting_point))
-        xr_dataset = S1_dataset.sel(x=slice(S1_dataset.x.isel(x=starting_point).data, S1_dataset.x.isel(x=-1).data),
-                                    y=slice(S1_dataset.y.isel(y=starting_point).data, S1_dataset.y.isel(y=-1).data))
-
-        predictions_batches_list = []
-        post_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time = post_time), input_dims = {'x': batch_size, 'y': batch_size})
-        pre1_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time = pre1_time), input_dims = {'x': batch_size, 'y': batch_size})
-        pre2_bgen = xbatcher.BatchGenerator(xr_dataset.sel(time = pre2_time), input_dims = {'x': batch_size, 'y': batch_size})
-        num_patches = len(post_bgen)
+    post = torch.clamp(torch.from_numpy(post).float(), min=0.0, max=clamp_input)
+    post = torch.nan_to_num(post, clamp_input)
+    pre1 = torch.clamp(torch.from_numpy(pre1).float(), min=0.0, max=clamp_input)
+    pre1 = torch.nan_to_num(pre1, clamp_input)
+    pre2 = torch.clamp(torch.from_numpy(pre2).float(), min=0.0, max=clamp_input)
+    pre2 = torch.nan_to_num(pre2, clamp_input)
 
-        for patch_i in tqdm(range(num_patches)):
+    # Normalize input data
+    post_event_norm = Normalize(post)
+    pre_event_1_norm = Normalize(pre1)
+    pre_event_2_norm = Normalize(pre2)
 
-            post_dB = np.stack([post_bgen[patch_i].VV_dB.values, post_bgen[patch_i].VH_dB.values], axis=0)
-            post = np.power(10, post_dB/10) # convert to linear
+    # Concatenate input data as expected by the model (post, pre1, pre2)
+    input_data = torch.cat((post_event_norm, pre_event_1_norm, pre_event_2_norm), dim=0)
 
+    # Padding (left, right, top, bottom)
+    padding_size = (patch_size, patch_size, patch_size, patch_size)
 
-            pre1_dB = np.stack([pre1_bgen[patch_i].VV_dB.values, pre1_bgen[patch_i].VH_dB.values], axis=0)
-            pre1 = np.power(10, pre1_dB/10) # convert to linear
+    # Apply padding using F.pad
+    input_data_pad = F.pad(input_data, padding_size, mode='constant', value=0)
 
-            pre2_dB = np.stack([pre2_bgen[patch_i].VV_dB.values, pre2_bgen[patch_i].VH_dB.values], axis=0)
-            pre2 = np.power(10, pre2_dB/10) # convert to linear
+    [bands, rows, cols] = input_data_pad.shape
 
+    predictions_list = []
 
-            post = torch.clamp(torch.from_numpy(post).float(), min=0.0, max=clamp_input)
-            post = torch.nan_to_num(post, clamp_input)
-            pre1 = torch.clamp(torch.from_numpy(pre1).float(), min=0.0, max=clamp_input)
-            pre1 = torch.nan_to_num(pre1, clamp_input)
-            pre2 = torch.clamp(torch.from_numpy(pre2).float(), min=0.0, max=clamp_input)
-            pre2 = torch.nan_to_num(pre2, clamp_input)
+    for starting_point in starting_points:
+        print('Predictions with starting point: {} pixel'.format(starting_point))
 
-            with torch.cuda.amp.autocast(enabled=False):
-                with torch.no_grad():
-                    post_event = Normalize(post).to(device).unsqueeze(0)
-                    pre_event_1 = Normalize(pre1).to(device).unsqueeze(0)
-                    pre_event_2 = Normalize(pre2).to(device).unsqueeze(0)
+        # Calculate the original number of patches along the width and height
+        num_patches_x = (cols - starting_point) // patch_size
+        num_patches_y = (rows - starting_point) // patch_size
 
-                    pre_event_1 = pre_event_1.to(device)
-                    post_event = torch.cat((post_event, pre_event_1), dim=1)
-                    post_event = torch.cat((post_event, pre_event_2.to(device)), dim=1)
-                    output = vit_model(post_event)
+        ending_point_x = num_patches_x*patch_size + starting_point
+        ending_point_y = num_patches_y*patch_size + starting_point
 
-                    predictions = output.argmax(1)
+        # Select section from the original image and add a batch dimension (1, B, H, W), since unfold expects a batched input
+        input_data_section = torch.tensor(input_data_pad[:, starting_point:ending_point_y, starting_point:ending_point_x]).unsqueeze(0)
+
+        # Use unfold to extract patches.
+        # It extracts patches as columns of shape (B * patch_size * patch_size, L),
+        # where L is the number of patches.
+        patches = F.unfold(input_data_section, kernel_size=patch_size, stride=patch_size)
+
+        # Reshape the patches to (num_patches, B, patch_size, patch_size);
+        # the number of patches (num_patches) is (H // patch_size) * (W // patch_size)
+        num_patches = patches.size(-1)
+        patches = patches.permute(0, 2, 1) # (1, L, B * patch_size * patch_size)
+        patches = patches.reshape(1, num_patches, bands, patch_size, patch_size)
+
+        # Remove the batch dimension if not needed (optional)
+        patches = patches.squeeze(0)
+
+        # Now patches contains all the patches, each with shape (bands, patch_size, patch_size)
+        # print(f"Patches shape: {patches.shape}")
+
+        num_patches = patches.shape[0]
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.no_grad():
+                predictions_patches = torch.zeros((num_patches, bands_out, patch_size, patch_size))
+                for i in tqdm(range(0, num_patches, batch_size)):
+                    patches_torch = torch.tensor(patches[i:i+batch_size, :, :, :]).to(device)
 
+                    output = vit_model(patches_torch).detach().cpu()
 
-            prediction_data = np.squeeze(predictions.to('cpu').numpy())
-
-            prediction_patch_xarray = xr.Dataset({'flood_vit': (["y","x"], prediction_data)},
-                                                 coords={
-                                                     "x": (["x"], post_bgen[patch_i].x.data),
-                                                     "y": (["y"], post_bgen[patch_i].y.data),
-                                                 },
-                                                 )
-            prediction_patch_xarray.rio.write_crs("epsg:4326", inplace=True)
-            predictions_batches_list.append(prediction_patch_xarray)
-        # merging all patches
-        prediction_merged_batches = xr.combine_by_coords(predictions_batches_list)
-        # reindexing to have the same shape as the given dataset
-        prediction_merged_batches = prediction_merged_batches.reindex_like(S1_dataset, method=None)
-        # append to list
-        prediction_data_list.append(prediction_merged_batches.flood_vit.data)
+                    predictions = output.argmax(1)
+                    predictions_patches[i:i+batch_size, :, :, :] = predictions
 
-    # stacking all prediction and calculate the most common prediction value
-    prediction_data_array = np.stack(prediction_data_list)
-    prediction_data_median = np.nanmedian(prediction_data_array, axis=0)
+        del predictions, patches_torch, patches, output
 
-    save_to_netcdf(S1_dataset = S1_dataset,
-                   prediction_data = prediction_data_median,
-                   flooded_region_filename = Floodpy_app.Flood_map_dataset_filename,
-                   flooded_regions_value = 2)
+        # Reshape the patches array back to the shape (num_patches_y, num_patches_x, bands_out, patch_size, patch_size)
+        predictions_patches = predictions_patches.reshape(num_patches_y, num_patches_x, bands_out, patch_size, patch_size)
+
+        # Transpose the axes back to (bands_out, num_patches_y * patch_size, num_patches_x * patch_size)
+        reconstructed_image = predictions_patches.permute(2, 0, 3, 1, 4).reshape(bands_out, num_patches_y * patch_size, num_patches_x * patch_size)
+
+        # print(f"Reconstructed image shape: {reconstructed_image.shape}")
+
+        reconstructed_image_full = torch.zeros((bands_out, rows, cols))
+        reconstructed_image_full[:, starting_point:ending_point_y, starting_point:ending_point_x] = reconstructed_image
+        predictions_list.append(reconstructed_image_full)
+        del reconstructed_image_full, reconstructed_image
+
+    # stack all predictions and keep the per-pixel median (most common class) across the shifted grids
+    prediction_data_array = np.stack(predictions_list)
+    del predictions_list
+    prediction_data_median = np.nanmedian(prediction_data_array, axis=0).squeeze()
+
+    # crop away the padding that was added before patching
+    prediction_data = prediction_data_median[patch_size:-patch_size, patch_size:-patch_size]
+
+    save_to_netcdf(S1_dataset = S1_dataset,
+                   prediction_data = prediction_data,
+                   flooded_region_filename = Floodpy_app.Flood_map_dataset_filename,
+                   flooded_regions_value = 2)
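
The core of the rewrite is an unfold/fold round trip: the padded raster is tiled into non-overlapping 224x224 patches, the ViT predicts each patch, and the outputs are stitched back into a full raster. Below is a minimal, self-contained sketch of that round trip; toy sizes and an identity "model" stand in for the real tensors and the ViT, so this is not FloodPy code:

import torch
import torch.nn.functional as F

patch_size = 4                    # 224 in Predict_flooded_regions.py
bands = 6                         # post + pre1 + pre2, two polarisations each
image = torch.arange(bands * 12 * 8, dtype=torch.float32).reshape(bands, 12, 8)

x = image.unsqueeze(0)            # unfold expects a batched (N, C, H, W) input

# extract non-overlapping patches -> (1, C * patch_size * patch_size, L)
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)
L = patches.size(-1)              # (12 // 4) * (8 // 4) = 6 patches

# rearrange to (L, C, patch_size, patch_size) for per-patch inference
patches = patches.permute(0, 2, 1).reshape(L, bands, patch_size, patch_size)

out = patches                     # identity "model" keeps the round trip checkable

# stitch back: (ny, nx, C, p, p) -> (C, ny * p, nx * p), as in the new code
ny, nx = 12 // patch_size, 8 // patch_size
restored = out.reshape(ny, nx, bands, patch_size, patch_size)
restored = restored.permute(2, 0, 3, 1, 4).reshape(bands, ny * patch_size, nx * patch_size)

assert torch.equal(restored, image)   # lossless round trip

The four starting_points shift the patch grid by quarter-patch offsets; stacking the four stitched rasters and taking np.nanmedian suppresses the blocking artefacts that a single, fixed grid would produce at patch borders.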
(file header not rendered; the diff below edits a SNAP GPT preprocessing graph XML)
@@ -76,7 +76,7 @@
   <node id="Calibration">
     <operator>Calibration</operator>
     <sources>
-      <sourceProduct refid="Subset(2)"/>
+      <sourceProduct refid="Remove-GRD-Border-Noise"/>
     </sources>
     <parameters class="com.bc.ceres.binding.dom.XppDomElement">
       <sourceBands/>
@@ -115,7 +115,7 @@
   <node id="Calibration(2)">
     <operator>Calibration</operator>
     <sources>
-      <sourceProduct refid="Remove-GRD-Border-Noise(2)"/>
+      <sourceProduct refid="Subset(2)"/>
     </sources>
     <parameters class="com.bc.ceres.binding.dom.XppDomElement">
       <sourceBands/>
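
This change rewires the two Calibration nodes: the first now consumes the Remove-GRD-Border-Noise output, the second the Subset(2) output. A quick way to sanity-check the wiring of any SNAP GPT graph (graph.xml is an illustrative path, not a file from this commit):

import xml.etree.ElementTree as ET

tree = ET.parse("graph.xml")  # illustrative path to a SNAP GPT graph
for node in tree.getroot().iter("node"):
    sources = [sp.get("refid") for sp in node.iter("sourceProduct")]
    print(node.get("id"), "<-", sources)

# Expected after this fix:
#   Calibration    <- ['Remove-GRD-Border-Noise']
#   Calibration(2) <- ['Subset(2)']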
12 changes: 1 addition & 11 deletions floodpy/Visualization/flood_over_time.py
@@ -5,16 +5,6 @@
 
 def plot_flooded_area_over_time(Floodpy_app, Floodpy_app_objs):
 
-    colorTones = {
-        6: '#CC3A5D', # dark pink
-        5: '#555555', # dark grey
-        4: '#A17C44', # dark brown
-        3: '#8751A1', # dark purple
-        2: '#C1403D', # dark red
-        1: '#2E5A87', # dark blue
-        0: '#57A35D', # dark green
-    }
-
     Flooded_regions_areas_km2 = {}
     for flood_date in Floodpy_app_objs.keys():
         # calculate the area of flooded regions
@@ -38,9 +28,9 @@ def getcolor(val):
 
     # Adjust the plot
     plt.ylabel('Flooded area (km²)', fontsize=16)
-    plt.title('Flooded Area(km²) Over Time', fontsize=16)
     plt.xticks(df['Datetime'].astype(str), df['Datetime'].dt.strftime('%d-%b-%Y'), rotation=30, ha='right', fontsize=16) # Set custom date format
     plt.yticks(fontsize=16)
     plt.grid()
+    plt.tight_layout() # Adjust layout for better fit
 
     # Display the plot
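
For reference, the flooded-area numbers this plot is built from come from reprojecting each vector flood map to its local UTM zone before summing polygon areas (areas computed in geographic EPSG:4326 degrees would be meaningless). A minimal sketch, with an illustrative filename:

import geopandas as gpd

gdf = gpd.read_file("flood_map.geojson")       # illustrative filename
gdf_utm = gdf.to_crs(gdf.estimate_utm_crs())   # metric CRS, so .area is in m²
area_km2 = round(gdf_utm.area.sum() / 1e6, 2)  # m² -> km²
print(f"Flooded area: {area_km2} km²")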
78 changes: 78 additions & 0 deletions floodpy/utils/add_metadata_to_geojson.py
@@ -0,0 +1,78 @@
import geopandas as gpd
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import os


def add_metadata(Floodpy_app_objs, Floodpy_app, plot_flag = True):

    distinctDarkTones = np.array([
        '#264653', # dark teal/gray
        '#2a9d8f', # deep green/teal
        '#1d3557', # dark blue
        '#4b5320', # army green
        '#039BE5', # vivid blue
        '#006400', # dark green
        '#81D4FA', # sky blue
    ])

    # choose the visualization colors (clamped so the index stays in range when there are more events than colors)
    num_flood_events = len(Floodpy_app_objs)
    color_indices = np.array([np.ceil(len(distinctDarkTones)*flood_ind/num_flood_events) for flood_ind in range(num_flood_events)], dtype=np.int32)
    color_indices = np.minimum(color_indices, len(distinctDarkTones) - 1)
    colors = distinctDarkTones[color_indices]

    # build the pandas dataframe with flooded regions and add metadata (plot_color and max_extend)
    Flooded_regions_df = pd.DataFrame()
    for flood_date in Floodpy_app_objs.keys():
        # calculate the area of flooded regions
        Flood_map_vector_data = gpd.read_file(Floodpy_app_objs[flood_date].Flood_map_vector_dataset_filename)
        Flood_map_vector_data_projected = Flood_map_vector_data.to_crs(Flood_map_vector_data.estimate_utm_crs())
        area_km2 = round(Flood_map_vector_data_projected.area.sum()/1000000, 2)
        flooded_region_temp = pd.DataFrame({'Flooded area (km2)': area_km2,
                                            'geojson_filename': Floodpy_app_objs[flood_date].Flood_map_vector_dataset_filename}, index=[flood_date])
        Flooded_regions_df = pd.concat([Flooded_regions_df, flooded_region_temp])

    # Ascending sorting of flood events based on flooded area
    Flooded_regions_df = Flooded_regions_df.sort_values(by=['Flooded area (km2)'])
    Flooded_regions_df['plot_color'] = colors
    Flooded_regions_df['max_extend'] = 'false'
    max_extend_ind = Flooded_regions_df['Flooded area (km2)'].idxmax()
    Flooded_regions_df.loc[max_extend_ind, ['max_extend']] = 'true'

    # overwrite existing geojson files with metadata information

    for index, row in Flooded_regions_df.iterrows():
        with open(row['geojson_filename']) as f:
            flooded_regions_json = json.load(f)

        # Add top-level metadata (e.g., title, description, etc.)
        flooded_regions_json['plot_color'] = row['plot_color']
        flooded_regions_json['max_extend'] = row['max_extend']

        # Save the modified GeoJSON with metadata to a file
        with open(row['geojson_filename'], "w") as f:
            json.dump(flooded_regions_json, f, indent=2)


    if plot_flag:
        Flooded_regions_df['Datetime'] = pd.to_datetime(Flooded_regions_df.index)

        df = Flooded_regions_df.sort_index().copy()
        # Plot the data
        fig = plt.figure(figsize=(6, 5))
        plt.bar(df['Datetime'].astype(str), df['Flooded area (km2)'], color='royalblue', width=0.7)

        # Adjust the plot
        plt.ylabel('Flooded area (km²)', fontsize=16)
        plt.xticks(df['Datetime'].astype(str), df['Datetime'].dt.strftime('%d-%b-%Y'), rotation=30, ha='right', fontsize=16) # Set custom date format
        plt.yticks(fontsize=16)
        plt.grid()
        plt.tight_layout() # Adjust layout for better fit

        # Display the plot
        fig_filename = os.path.join(Floodpy_app.Results_dir, '{}.svg'.format(Floodpy_app.flood_event))
        plt.savefig(fig_filename, format="svg")
        # plt.close()
        print('The figure can be found at: {}'.format(fig_filename))
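
Since add_metadata writes plot_color and max_extend as top-level GeoJSON members, downstream consumers can read them back with plain json. A hedged snippet: the filename is illustrative and would normally come from Flood_map_vector_dataset_filename:

import json

with open("flood_map.geojson") as f:  # illustrative filename
    fc = json.load(f)

print(fc["plot_color"])  # e.g. '#264653'
print(fc["max_extend"])  # 'true' only for the event with the largest flooded area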
18 changes: 1 addition & 17 deletions floodpy/utils/geo_utils.py
@@ -7,20 +7,8 @@
 import datetime
 import json
 
-colorTones = {
-    6: '#CC3A5D', # dark pink
-    5: '#555555', # dark grey
-    4: '#A17C44', # dark brown
-    3: '#8751A1', # dark purple
-    2: '#C1403D', # dark red
-    1: '#2E5A87', # dark blue
-    0: '#57A35D', # dark green
-}
-
-
 def create_polygon(coordinates):
-    return Polygon(coordinates['coordinates'][0])
-
+    return Polygon(np.array(coordinates['coordinates']).squeeze())
 
 def convert_to_vector(Floodpy_app):
     with rasterio.open(Floodpy_app.Flood_map_dataset_filename) as src:
@@ -42,16 +30,12 @@ def convert_to_vector(Floodpy_app):
         geojson_str = gdf.to_json() # This gives the GeoJSON as a string
         geojson_dict = json.loads(geojson_str) # Convert the string to a dictionary
 
-        # find the color of plotting
-        color_ind = Floodpy_app.flood_datetimes.index(Floodpy_app.flood_datetime)
-        plot_color = colorTones[color_ind]
         #Add top-level metadata (e.g., title, description, etc.)
         geojson_dict['flood_event'] = Floodpy_app.flood_event
         geojson_dict['description'] = "This GeoJSON contains polygons of flooded regions using Sentinel-1 data."
         geojson_dict['produced_by'] = "Floodpy"
         geojson_dict['creation_date_UTC'] = datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%dT%H%M%S')
         geojson_dict['flood_datetime_UTC'] = Floodpy_app.flood_datetime_str
-        geojson_dict['plot_color'] = plot_color
         geojson_dict['bbox'] = Floodpy_app.bbox
 
         #Save the modified GeoJSON with metadata to a file
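
The create_polygon fix makes the helper tolerant of nesting: np.array(...).squeeze() collapses singleton dimensions before handing the ring to shapely, whereas the old [0] indexing assumed exactly one level of nesting in the coordinates. A toy check, using a GeoJSON-style dict of the kind rasterio.features.shapes() yields:

import numpy as np
from shapely.geometry import Polygon

geom = {
    "type": "Polygon",
    "coordinates": [[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0), (0.0, 0.0)]],
}

# old: Polygon(geom["coordinates"][0])
# new: also robust to an extra singleton nesting level
ring = np.array(geom["coordinates"]).squeeze()  # (5, 2) array of vertices
print(Polygon(ring).area)  # 1.0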
