Commit

add support for complex64 tensors

Your Name committed Jan 25, 2021
1 parent 6ed32b0 commit d54e793

Showing 4 changed files with 237 additions and 103 deletions.
164 changes: 164 additions & 0 deletions Example.ipynb
@@ -0,0 +1,164 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets, transforms\n",
"from complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear, NaiveComplexBatchNorm2d\n",
"from complexFunctions import complex_relu, complex_max_pool2d"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 64\n",
"trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n",
"train_set = datasets.MNIST('../data', train=True, transform=trans, download=True)\n",
"test_set = datasets.MNIST('../data', train=False, transform=trans, download=True)\n",
"\n",
"train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ComplexNet(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(ComplexNet, self).__init__()\n",
" self.conv1 = ComplexConv2d(1, 10, 5, 1)\n",
" self.bn = ComplexBatchNorm2d(10)\n",
" self.conv2 = ComplexConv2d(10, 20, 5, 1)\n",
" self.fc1 = ComplexLinear(4*4*20, 500)\n",
" self.fc2 = ComplexLinear(500, 10)\n",
" \n",
" def forward(self,x):\n",
" x = self.conv1(x)\n",
" x = complex_relu(x)\n",
" x = complex_max_pool2d(x, 2, 2)\n",
" x = self.bn(x)\n",
" x = self.conv2(x)\n",
" x = complex_relu(x)\n",
" x = complex_max_pool2d(x, 2, 2)\n",
" x = x.view(-1,4*4*20)\n",
" x = self.fc1(x)\n",
" x = complex_relu(x)\n",
" x = self.fc2(x)\n",
" x = x.abs()\n",
" x = F.log_softmax(x, dim=1)\n",
" return x\n",
" \n",
"device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = ComplexNet().to(device)\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
"\n",
"def train(model, device, train_loader, optimizer, epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" data, target =data.to(device).type(torch.complex64), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" if batch_idx % 100 == 0:\n",
" print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch,\n",
" batch_idx * len(data), \n",
" len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), \n",
" loss.item())\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 0 [ 0/ 60000 ( 0%)]\tLoss: 2.349018\n",
"Train Epoch: 0 [ 6400/ 60000 ( 11%)]\tLoss: 0.252006\n",
"Train Epoch: 0 [ 12800/ 60000 ( 21%)]\tLoss: 0.094634\n",
"Train Epoch: 0 [ 19200/ 60000 ( 32%)]\tLoss: 0.096171\n",
"Train Epoch: 0 [ 25600/ 60000 ( 43%)]\tLoss: 0.039067\n",
"Train Epoch: 0 [ 32000/ 60000 ( 53%)]\tLoss: 0.062306\n",
"Train Epoch: 0 [ 38400/ 60000 ( 64%)]\tLoss: 0.091644\n",
"Train Epoch: 0 [ 44800/ 60000 ( 75%)]\tLoss: 0.154324\n",
"Train Epoch: 0 [ 51200/ 60000 ( 85%)]\tLoss: 0.015835\n",
"Train Epoch: 0 [ 57600/ 60000 ( 96%)]\tLoss: 0.005899\n",
"Train Epoch: 1 [ 0/ 60000 ( 0%)]\tLoss: 0.013530\n",
"Train Epoch: 1 [ 6400/ 60000 ( 11%)]\tLoss: 0.031689\n",
"Train Epoch: 1 [ 12800/ 60000 ( 21%)]\tLoss: 0.025631\n",
"Train Epoch: 1 [ 19200/ 60000 ( 32%)]\tLoss: 0.031679\n",
"Train Epoch: 1 [ 25600/ 60000 ( 43%)]\tLoss: 0.021937\n",
"Train Epoch: 1 [ 32000/ 60000 ( 53%)]\tLoss: 0.095149\n",
"Train Epoch: 1 [ 38400/ 60000 ( 64%)]\tLoss: 0.008647\n",
"Train Epoch: 1 [ 44800/ 60000 ( 75%)]\tLoss: 0.088300\n",
"Train Epoch: 1 [ 51200/ 60000 ( 85%)]\tLoss: 0.003999\n",
"Train Epoch: 1 [ 57600/ 60000 ( 96%)]\tLoss: 0.004459\n",
"Train Epoch: 2 [ 0/ 60000 ( 0%)]\tLoss: 0.003121\n",
"Train Epoch: 2 [ 6400/ 60000 ( 11%)]\tLoss: 0.003100\n",
"Train Epoch: 2 [ 12800/ 60000 ( 21%)]\tLoss: 0.001305\n",
"Train Epoch: 2 [ 19200/ 60000 ( 32%)]\tLoss: 0.017995\n"
]
}
],
"source": [
"# Run training on 4 epochs\n",
"for epoch in range(4):\n",
" train(model, device, train_loader, optimizer, epoch)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}
61 changes: 27 additions & 34 deletions README.md
100644 → 100755
@@ -2,6 +2,11 @@

A high-level toolbox for using complex valued neural networks in PyTorch.

Before version 1.7 of PyTorch, complex tensors were not supported.
The initial version of **complexPyTorch** represented a complex tensor with two real tensors, one for the real and one for the imaginary part.
Since version 1.7, complex tensors of type `torch.complex64` are supported, but only a limited number of operations are implemented for them.
The current version of **complexPyTorch** uses complex tensors (and hence requires PyTorch >= 1.7) and adds support for various operations and layers.
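
As a brief illustration of the change in data representation (a minimal sketch, assuming PyTorch >= 1.7; the shapes are arbitrary):

```python
import torch

# Before 1.7: a complex quantity had to be carried as two real tensors.
xr, xi = torch.randn(2, 3), torch.randn(2, 3)

# Since 1.7: a single complex64 tensor holds both parts.
x = torch.complex(xr, xi)
print(x.dtype)                     # torch.complex64
print(x.real.shape, x.imag.shape)  # both parts remain accessible
print(x.abs().dtype)               # torch.float32 (element-wise magnitude)
```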

## Complex Valued Networks with PyTorch

Artificial neural networks are mainly used to process data encoded as real values, such as digitized images or sounds.
@@ -17,7 +22,6 @@ Following [[C. Trabelsi et al., International Conference on Learning Representations
* BatchNorm2d (Naive and Covariance approach)



## Syntax and usage

The syntax mirrors that of the standard real-valued functions and modules from PyTorch.
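
For example, a complex layer from `complexLayers` is declared and called with the same arguments as its real counterpart (a hedged sketch; the shapes and channel counts are illustrative):

```python
import torch
from complexLayers import ComplexConv2d
from complexFunctions import complex_relu

conv = ComplexConv2d(1, 10, 5, 1)                     # same arguments as nn.Conv2d(1, 10, 5, 1)
x = torch.randn(8, 1, 28, 28).type(torch.complex64)  # complex-valued input batch (imaginary part zero)
y = complex_relu(conv(x))                             # drop-in replacements for conv2d and relu
```
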
@@ -58,49 +62,42 @@ class ComplexNet(nn.Module):

def __init__(self):
super(ComplexNet, self).__init__()
self.conv1 = ComplexConv2d(1, 20, 5, 1)
self.bn = ComplexBatchNorm2d(20)
self.conv2 = ComplexConv2d(20, 50, 5, 1)
self.fc1 = ComplexLinear(4*4*50, 500)
self.conv1 = ComplexConv2d(1, 10, 5, 1)
self.bn = ComplexBatchNorm2d(10)
self.conv2 = ComplexConv2d(10, 20, 5, 1)
self.fc1 = ComplexLinear(4*4*20, 500)
self.fc2 = ComplexLinear(500, 10)

def forward(self,x):
xr = x
# imaginary part to zero
xi = torch.zeros(xr.shape, dtype = xr.dtype, device = xr.device)
xr,xi = self.conv1(xr,xi)
xr,xi = complex_relu(xr,xi)
xr,xi = complex_max_pool2d(xr,xi, 2, 2)


xr,xi = self.bn(xr,xi)
xr,xi = self.conv2(xr,xi)
xr,xi = complex_relu(xr,xi)
xr,xi = complex_max_pool2d(xr,xi, 2, 2)

xr = xr.view(-1, 4*4*50)
xi = xi.view(-1, 4*4*50)
xr,xi = self.fc1(xr,xi)
xr,xi = complex_relu(xr,xi)
xr,xi = self.fc2(xr,xi)
# take the absolute value as output
x = torch.sqrt(torch.pow(xr,2)+torch.pow(xi,2))
return F.log_softmax(x, dim=1)
x = self.conv1(x)
x = complex_relu(x)
x = complex_max_pool2d(x, 2, 2)
x = self.bn(x)
x = self.conv2(x)
x = complex_relu(x)
x = complex_max_pool2d(x, 2, 2)
x = x.view(-1,4*4*20)
x = self.fc1(x)
x = complex_relu(x)
x = self.fc2(x)
x = x.abs()
x = F.log_softmax(x, dim=1)
return x

device = torch.device("cuda:0" )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ComplexNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
data, target = data.to(device).type(torch.complex64), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 1000 == 0:
if batch_idx % 100 == 0:
print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format(
epoch,
batch_idx * len(data),
@@ -113,11 +110,7 @@ def train(model, device, train_loader, optimizer, epoch):
for epoch in range(50):
train(model, device, train_loader, optimizer, epoch)
```
## Todo
* Script ComplexBatchNorm for improved efficiency ([jit doc](https://pytorch.org/docs/stable/jit.html))
* Add more layers (Conv1D, Upsample, ConvTranspose...)
* Add complex cost functions and usual functions (e.g. Pearson correlation)

## Acknowledgments

23 changes: 12 additions & 11 deletions complexFunctions.py
100644 → 100755
@@ -6,23 +6,24 @@
"""

from torch.nn.functional import relu, max_pool2d, dropout, dropout2d
import torch

def complex_relu(input_r,input_i):
return relu(input_r), relu(input_i)
def complex_relu(input):
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)

def complex_max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0,
def complex_max_pool2d(input,kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):

return max_pool2d(input_r, kernel_size, stride, padding, dilation,
ceil_mode, return_indices), \
max_pool2d(input_i, kernel_size, stride, padding, dilation,
ceil_mode, return_indices)
return max_pool2d(input.real, kernel_size, stride, padding, dilation,
ceil_mode, return_indices).type(torch.complex64) \
+ 1j*max_pool2d(input.imag, kernel_size, stride, padding, dilation,
ceil_mode, return_indices).type(torch.complex64)

def complex_dropout(input_r,input_i, p=0.5, training=True, inplace=False):
return dropout(input_r, p, training, inplace), \
dropout(input_i, p, training, inplace)
return dropout(input_r, p, training, inplace).type(torch.complex64) \
+1j*dropout(input_i, p, training, inplace).type(torch.complex64)


def complex_dropout2d(input_r,input_i, p=0.5, training=True, inplace=False):
return dropout2d(input_r, p, training, inplace), \
dropout2d(input_i, p, training, inplace)
return dropout2d(input_r, p, training, inplace).type(torch.complex64) \
+1j*dropout2d(input_i, p, training, inplace).type(torch.complex64)
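
A quick sanity check of the rewritten functions above (a sketch, assuming the repository's `complexFunctions` module is on the path):

```python
import torch
from complexFunctions import complex_relu, complex_max_pool2d

x = torch.randn(1, 3, 8, 8).type(torch.complex64)
y = complex_relu(x)              # ReLU applied independently to real and imaginary parts
z = complex_max_pool2d(x, 2, 2)  # max pooling applied independently to each part
assert y.dtype == torch.complex64 and z.dtype == torch.complex64
print(z.shape)                   # torch.Size([1, 3, 4, 4])
```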