-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
32 lines (31 loc) · 771 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import segmentation_models_pytorch as smp
# Training schedule and batching.
epochs = 25
train_batch_size = val_batch_size = 32

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Input bands, stacked as image channels in this order. The IDs look like
# Sentinel-2 band names — NOTE(review): confirm against the dataset loader.
bands = ["B02", "B03", "B04", "B08"]

# Criterion instance, plus the optimizer *class* (deliberately not
# instantiated here). Presumably the training code binds it to the model's
# parameters together with `learning_rate` — verify against the caller.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam
# Segmentation network: DeepLabV3+ with a ResNet-101 encoder.
# `in_channels` is derived from `bands` so the model input width always
# matches the number of bands fed to it. `classes=2` is the number of
# output channels (one logit map per class).
# NOTE(review): `encoder_weights` is left at the library default, which
# may trigger a pretrained-weight download at import time — confirm this
# against the installed segmentation_models_pytorch version.
model = smp.DeepLabV3Plus(
    encoder_name="resnet101",
    in_channels=len(bands),
    classes=2,
)

# Mixed-precision loss scaler, enabled only when CUDA is present. On a
# CPU-only machine GradScaler would disable itself anyway. Passing `enabled`
# explicitly makes that intent visible and avoids the runtime warning it
# otherwise emits.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Learning rate, presumably consumed together with `optimizer` by the
# training code — verify against the caller.
learning_rate = 3e-4
# Albumentations pipelines. Both end with ToTensorV2, which converts the
# sample (image and, if supplied, mask) to torch tensors.
#
# Training: a random rotation drawn from the +/-60 degree range, applied to
# 60% of samples, plus independent horizontal and vertical flips (each
# applied to 50% of samples).
train_transforms = A.Compose(
    transforms=[
        A.Rotate(limit=60, p=0.6),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(),
    ],
)

# Validation: no augmentation at all, only tensor conversion.
val_transforms = A.Compose(transforms=[ToTensorV2()])