diff --git a/.githooks/pre-commit b/.githooks/pre-commit
new file mode 100755
index 00000000..f3a5aa98
--- /dev/null
+++ b/.githooks/pre-commit
@@ -0,0 +1,69 @@
+#!/bin/bash
+#
+# A hook script to verify what is about to be committed.
+# Called by "git commit" with no arguments. The hook should
+# exit with non-zero status after issuing an appropriate message if
+# it wants to stop the commit.
+
+# Fail immediately at first issue with the relevant exit status.
+set -eo pipefail
+
+# ===================================================================
+
+if git rev-parse --verify HEAD >/dev/null 2>&1
+then
+ against=HEAD
+else
+ # Initial commit: diff against an empty tree object
+ against=$(git hash-object -t tree /dev/null)
+fi
+
+# ===================================================================
+
+# Check that ftorch.90 is not modified and staged alone.
+git diff --cached --name-only | if grep --quiet "ftorch.f90"; then
+ git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then
+ cat <<\EOF
+Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed.
+ftorch.90 should be generated from ftorch.fypp using fypp.
+Please restore ftorch.f90 and make your modifications to ftorch.fypp instead.
+EOF
+ exit 1
+ fi
+fi
+
+# Check to see if ftorch.fypp has been modified AND is staged.
+git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then
+
+ # Check that ftorch.90 is also modified and staged.
+ git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then
+ cat <<\EOF
+Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not.
+ftorch.90 should be generated from ftorch.fypp and both committed together.
+Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together.
+EOF
+ exit 1
+ else
+ # Check fypp available, and raise error and exit if not.
+ if ! command -v fypp &> /dev/null; then
+ cat <<\EOF
+echo "Error: Could not find fypp to run on ftorch.fypp.
+Please install fypp using "pip install fypp" and then try committing again.
+EOF
+ exit 1
+ fi
+
+ # If fypp is available and both .f90 and .fypp staged, check they match.
+ fypp src/ftorch.fypp src/ftorch.f90_tmp
+ if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then
+ rm src/ftorch.f90_tmp
+ cat <<\EOF
+Error: The code in ftorch.f90 does not match that expected from ftorch.fypp.
+Please re-run fypp on ftorch.fypp to ensure consistency before committing.
+EOF
+ exit 1
+ else
+ rm src/ftorch.f90_tmp
+ fi
+ fi
+fi
diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml
new file mode 100644
index 00000000..88f5af8c
--- /dev/null
+++ b/.github/workflows/fypp.yml
@@ -0,0 +1,27 @@
+name: fypp-checks
+
+on:
+ # run on every push
+ push:
+
+jobs:
+ various:
+ name: FYPP checks - runs check on fypp and f90 files
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.11"
+ - run: pip install fypp
+
+ - name: Check fypp matches f90
+ run: |
+ fypp src/ftorch.fypp src/temp.f90_temp
+ if ! diff -q src/ftorch.f90 src/temp.f90_temp; then
+ echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp."
+ echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit."
+ exit 1
+ else
+ exit 0
+ fi
diff --git a/.github/workflows/python_qc.yaml b/.github/workflows/python_qc.yaml
new file mode 100644
index 00000000..bcaf7ac5
--- /dev/null
+++ b/.github/workflows/python_qc.yaml
@@ -0,0 +1,32 @@
+name: python-code-qc
+
+on:
+ # run on every push to main
+ push:
+ branches:
+ - main
+ # run on every push (not commit) to a PR, plus open/reopen
+ pull_request:
+ types:
+ - synchronize
+ - opened
+ - reopened
+
+jobs:
+ various:
+ name: Python Code QC (Black, pydocstyle)
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.11"
+ - run: pip install black pydocstyle
+
+ # annotate each step with `if: always` to run all regardless
+ - name: Assert that code matches Black code style
+ if: always()
+ uses: psf/black@stable
+ - name: Lint with pydocstyle
+ if: always()
+ run: pydocstyle --convention=numpy ./
diff --git a/README.md b/README.md
index 2019a51e..a02120d9 100644
--- a/README.md
+++ b/README.md
@@ -79,12 +79,21 @@ To build and install the library:
| [`CMAKE_PREFIX_PATH`](https://cmake.org/cmake/help/latest/variable/CMAKE_PREFIX_PATH.html) | `` | Location of Torch installation1 |
| [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `` | Location at which the library files should be installed. By default this is `/usr/local` |
| [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code|
+ | `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA2 |
1 _The path to the Torch installation needs to allow cmake to locate the relevant Torch cmake files.
If Torch has been [installed as libtorch](https://pytorch.org/cppdocs/installing.html)
then this should be the absolute path to the unzipped libtorch distribution.
If Torch has been installed as PyTorch in a python [venv (virtual environment)](https://docs.python.org/3/library/venv.html),
e.g. with `pip install torch`, then this should be `lib/python<3.xx>/site-packages/torch/`._
+
+ 2 _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g.
+ `pip install torch --index-url https://download.pytorch.org/whl/cpu`
+ or
+ `pip install torch --index-url https://download.pytorch.org/whl/cu118`
+ (for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._
+
+
4. Make and install the code to the chosen location with:
```
make
@@ -92,7 +101,9 @@ To build and install the library:
```
This will place the following directories at the install location:
* `CMAKE_INSTALL_PREFIX/include/` - contains header and mod files
- * `CMAKE_INSTALL_PREFIX/lib64/` - contains `cmake` directory and `.so` files
+ * `CMAKE_INSTALL_PREFIX/lib/` - contains `cmake` directory and `.so` files
+ _Note: depending on your system and architecture `lib` may be `lib64`, and
+ you may have `.dll` files or similar._
## Usage
@@ -119,11 +130,10 @@ To use the trained Torch model from within Fortran we need to import the `ftorch
A very simple example is given below.
For more detailed documentation please consult the API documentation, source code, and examples.
-This minimal snippet loads a saved Torch model, creates inputs consisting of two `10x10` matrices (one of ones, and one of zeros), and runs the model to infer output.
+This minimal snippet loads a saved Torch model, creates an input consisting of a `10x10` matrix of ones, and runs the model to infer output.
+This is illustrative only, and we recommend following the [examples](examples/) before writing your own code to explore more features.
```fortran
-! Import any C bindings as required for this code
-use, intrinsic :: iso_c_binding, only: c_int, c_int64_t, c_loc
! Import library for interfacing with PyTorch
use ftorch
@@ -132,34 +142,32 @@ implicit none
! Generate an object to hold the Torch model
type(torch_module) :: model
-! Set up types of input and output data and the interface with C
-integer(c_int), parameter :: dims_input = 2
-integer(c_int64_t) :: shape_input(dims_input)
-integer(c_int), parameter :: n_inputs = 2
+! Set up types of input and output data
+integer, parameter :: n_inputs = 1
type(torch_tensor), dimension(n_inputs) :: model_input_arr
-integer(c_int), parameter :: dims_output = 1
-integer(c_int64_t) :: shape_output(dims_output)
type(torch_tensor) :: model_output
-! Set up the model inputs as Fortran arrays
-real, dimension(10,10), target :: input_1, input_2
+! Set up the model input and output as Fortran arrays
+real, dimension(10,10), target :: input
real, dimension(5), target :: output
+! Set up number of dimensions of input tensor and axis order
+integer, parameter :: in_dims = 2
+integer :: in_layout(in_dims) = [1,2]
+integer, parameter :: out_dims = 1
+integer :: out_layout(out_dims) = [1]
+
! Initialise the Torch model to be used
model = torch_module_load("/path/to/saved/model.pt")
-! Initialise the inputs as Fortran
-input_1 = 0.0
-input_2 = 1.0
+! Initialise the inputs as Fortran array of ones
+input = 1.0
! Wrap Fortran data as no-copy Torch Tensors
! There may well be some reshaping required depending on the
! structure of the model which is not covered here (see examples)
-shape_input = (/10, 10/)
-shape_output = (/5/)
-model_input_arr(1) = torch_tensor_from_blob(c_loc(input_1), dims_input, shape_input, torch_kFloat64, torch_kCPU)
-model_input_arr(2) = torch_tensor_from_blob(c_loc(input_2), dims_input, shape_input, torch_kFloat64, torch_kCPU)
-model_output = torch_tensor_from_blob(c_loc(output), dims_output, shape_output, torch_kFloat64, torch_kCPU)
+model_input_arr(1) = torch_tensor_from_array(input, in_layout, torch_kCPU)
+model_output = torch_tensor_from_array(output, out_layout, torch_kCPU)
! Run model and Infer
! Again, there may be some reshaping required depending on model design
@@ -171,7 +179,6 @@ write(*,*) output
! Clean up
call torch_module_delete(model)
call torch_tensor_delete(model_input_arr(1))
-call torch_tensor_delete(model_input_arr(2))
call torch_tensor_delete(model_output)
```
@@ -189,7 +196,10 @@ find_package(FTorch)
target_link_libraries( PRIVATE FTorch::ftorch )
message(STATUS "Building with Fortran PyTorch coupling")
```
-and using the `-DFTorch_DIR=` flag when running cmake.
+and using the `-DCMAKE_PREFIX_PATH=` flag when running cmake.
+_Note: If you used the `CMAKE_INSTALL_PREFIX` argument when
+[building and installing the library](#library-installation) above then you should use
+the same path for ``._
#### Make
To build with make we need to include the library when compiling and link the executable
@@ -203,14 +213,35 @@ FCFLAGS += -I/include/ftorch
When compiling the final executable add the following link flag:
```
-LDFLAGS += -L/lib64 -lftorch
+LDFLAGS += -L/lib -lftorch
```
You may also need to add the location of the `.so` files to your `LD_LIBRARY_PATH`
unless installing in a default location:
```
-export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:/lib64
+export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:/lib
```
+_Note: depending on your system and architecture `lib` may be `lib64` or something similar._
+
+### 4. Running on GPUs
+
+In order to run a model on GPU, two main changes are required:
+
+1. When saving your TorchScript model, ensure that it is on the GPU. For example, when using [pt2ts.py](utils/pt2ts.py), this can be done by uncommenting the following lines:
+
+```
+device = torch.device('cuda')
+trained_model = trained_model.to(device)
+trained_model.eval()
+trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
+trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)
+```
+
+Note: this also moves the dummy input tensors to the GPU. This is not necessary for saving the model, but the tensors must also be on the GPU to test that the models runs.
+
+
+2. When calling `torch_tensor_from_blob` in Fortran, the device for the input tensor(s), but not the output tensor(s),
+ should be set to `torch_kCUDA`, rather than `torch_kCPU`. This ensures that the inputs are on the same device as the model.
## Examples
diff --git a/examples/1_SimpleNet/README.md b/examples/1_SimpleNet/README.md
index 3bc5e0cb..dc1a9d77 100644
--- a/examples/1_SimpleNet/README.md
+++ b/examples/1_SimpleNet/README.md
@@ -74,10 +74,9 @@ This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
-cmake .. -DFTorch_DIR=lib/cmake/ -DCMAKE_BUILD_TYPE=Release
+cmake .. -DCMAKE_PREFIX_PATH= -DCMAKE_BUILD_TYPE=Release
make
```
-Make sure that the `FTorch_DIR` flag points to the `lib/cmake/` folder within the installation of the FTorch library.
To run the compiled code calling the saved SimpleNet TorchScript from Fortran run the
executable with an argument of the saved model file:
diff --git a/examples/1_SimpleNet/simplenet.py b/examples/1_SimpleNet/simplenet.py
index 3b743acf..e93189e6 100644
--- a/examples/1_SimpleNet/simplenet.py
+++ b/examples/1_SimpleNet/simplenet.py
@@ -1,4 +1,4 @@
-"""Module defining a simple PyTorch 'Net' for coupling to Fortran. """
+"""Module defining a simple PyTorch 'Net' for coupling to Fortran."""
import torch
from torch import nn
@@ -15,7 +15,6 @@ def __init__(
Consists of a single Linear layer with weights predefined to
multiply the input by 2.
"""
-
super().__init__()
self._fwd_seq = nn.Sequential(
nn.Linear(5, 5, bias=False),
@@ -38,7 +37,6 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
batch scaled by 2.
"""
-
return self._fwd_seq(batch)
diff --git a/examples/1_SimpleNet/simplenet_infer_fortran b/examples/1_SimpleNet/simplenet_infer_fortran
new file mode 100755
index 00000000..de4a84a8
Binary files /dev/null and b/examples/1_SimpleNet/simplenet_infer_fortran differ
diff --git a/examples/1_SimpleNet/simplenet_infer_fortran.f90 b/examples/1_SimpleNet/simplenet_infer_fortran.f90
index 2f1ff385..199b984c 100644
--- a/examples/1_SimpleNet/simplenet_infer_fortran.f90
+++ b/examples/1_SimpleNet/simplenet_infer_fortran.f90
@@ -1,28 +1,30 @@
program inference
- ! Imports primitives used to interface with C
- use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc
+ ! Import precision info from iso
+ use, intrinsic :: iso_fortran_env, only : sp => real32
+
! Import our library for interfacing with PyTorch
use ftorch
implicit none
-
+
+ ! Set precision for reals
+ integer, parameter :: wp = sp
+
integer :: num_args, ix
character(len=128), dimension(:), allocatable :: args
- ! Set up types of input and output data and the interface with C
+ ! Set up Fortran data structures
+ real(wp), dimension(5), target :: in_data
+ real(wp), dimension(5), target :: out_data
+ integer, parameter :: n_inputs = 1
+ integer :: tensor_layout(1) = [1]
+
+ ! Set up Torch data structures
type(torch_module) :: model
type(torch_tensor), dimension(1) :: in_tensor
type(torch_tensor) :: out_tensor
- real(c_float), dimension(:), allocatable, target :: in_data
- integer(c_int), parameter :: n_inputs = 1
- real(c_float), dimension(:), allocatable, target :: out_data
-
- integer(c_int), parameter :: tensor_dims = 1
- integer(c_int64_t) :: tensor_shape(tensor_dims) = [5]
- integer(c_int) :: tensor_layout(tensor_dims) = [1]
-
! Get TorchScript model file as a command line argument
num_args = command_argument_count()
allocate(args(num_args))
@@ -30,18 +32,14 @@ program inference
call get_command_argument(ix,args(ix))
end do
- ! Allocate one-dimensional input/output arrays, based on multiplication of all input/output dimension sizes
- allocate(in_data(tensor_shape(1)))
- allocate(out_data(tensor_shape(1)))
-
! Initialise data
in_data = [0.0, 1.0, 2.0, 3.0, 4.0]
- ! Create input/output tensors from the above arrays
- in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
- out_tensor = torch_tensor_from_blob(c_loc(out_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
+ ! Create Torch input/output tensors from the above arrays
+ in_tensor(1) = torch_tensor_from_array(in_data, tensor_layout, torch_kCPU)
+ out_tensor = torch_tensor_from_array(out_data, tensor_layout, torch_kCPU)
- ! Load ML model (edit this line to use different models)
+ ! Load ML model
model = torch_module_load(args(1))
! Infer
@@ -52,7 +50,5 @@ program inference
call torch_module_delete(model)
call torch_tensor_delete(in_tensor(1))
call torch_tensor_delete(out_tensor)
- deallocate(in_data)
- deallocate(out_data)
end program inference
diff --git a/examples/1_SimpleNet/simplenet_infer_python.py b/examples/1_SimpleNet/simplenet_infer_python.py
index b3aadc29..54570882 100644
--- a/examples/1_SimpleNet/simplenet_infer_python.py
+++ b/examples/1_SimpleNet/simplenet_infer_python.py
@@ -21,7 +21,6 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
output : torch.Tensor
result of running inference on model with example Tensor input
"""
-
input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]).repeat(batch_size, 1)
if device == "cpu":
diff --git a/examples/2_ResNet18/Makefile b/examples/2_ResNet18/Makefile
index 6e886217..570fc181 100644
--- a/examples/2_ResNet18/Makefile
+++ b/examples/2_ResNet18/Makefile
@@ -6,7 +6,7 @@ FC = gfortran
FCFLAGS = -O3 -I/include/ftorch
# link flags
-LDFLAGS = -L/lib64/ -lftorch
+LDFLAGS = -L/lib/ -lftorch
PROGRAM = resnet_infer_fortran
SRC = resnet_infer_fortran.f90
diff --git a/examples/2_ResNet18/README.md b/examples/2_ResNet18/README.md
index e5e90f91..2d5d7aff 100644
--- a/examples/2_ResNet18/README.md
+++ b/examples/2_ResNet18/README.md
@@ -70,10 +70,9 @@ This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
-cmake .. -DFTorch_DIR=lib/cmake/ -DCMAKE_BUILD_TYPE=Release
+cmake .. -DCMAKE_PREFIX_PATH= -DCMAKE_BUILD_TYPE=Release
make
```
-Make sure that the `FTorch_DIR` flag points to the `lib/cmake/` folder within the installation of the FTorch library.
To run the compiled code calling the saved ResNet-18 TorchScript from Fortran run the
executable with an argument of the saved model file:
@@ -96,7 +95,7 @@ installation of FTorch as described in the main documentation. Also check that t
You will also likely need to add the location of the `.so` files to your `LD_LIBRARY_PATH`:
```
make
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/lib64
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/lib
./resnet_infer_fortran saved_resnet18_model_cpu.pt
```
diff --git a/examples/2_ResNet18/resnet18.py b/examples/2_ResNet18/resnet18.py
index 0f8b070a..0bf9c2bd 100644
--- a/examples/2_ResNet18/resnet18.py
+++ b/examples/2_ResNet18/resnet18.py
@@ -21,7 +21,6 @@ def initialize(precision: torch.dtype) -> torch.nn.Module:
model: torch.nn.Module
Pretrained ResNet-18 model
"""
-
# Set working precision
torch.set_default_dtype(precision)
@@ -96,7 +95,7 @@ def run_model(model: torch.nn.Module, precision: type) -> None:
def print_top_results(output: torch.Tensor) -> None:
- """Prints top 5 results
+ """Print top 5 results.
Parameters
----------
diff --git a/examples/2_ResNet18/resnet_infer_fortran.f90 b/examples/2_ResNet18/resnet_infer_fortran.f90
index dfc012b1..1af256af 100644
--- a/examples/2_ResNet18/resnet_infer_fortran.f90
+++ b/examples/2_ResNet18/resnet_infer_fortran.f90
@@ -1,18 +1,12 @@
program inference
- ! Imports primitives used to interface with C
- use, intrinsic :: iso_c_binding, only: c_sp=>c_float, c_dp=>c_double, c_int64_t, c_loc
- use, intrinsic :: iso_fortran_env, only : sp => real32, dp => real64
+ use, intrinsic :: iso_fortran_env, only : sp => real32
! Import our library for interfacing with PyTorch
use :: ftorch
implicit none
- ! Define working precision for C primitives
- ! Precision must match `wp` in resnet18.py and `wp_torch` in pt2ts.py
- integer, parameter :: c_wp = c_sp
integer, parameter :: wp = sp
- integer, parameter :: torch_wp = torch_kFloat32
call main()
@@ -25,21 +19,21 @@ subroutine main()
integer :: num_args, ix
character(len=128), dimension(:), allocatable :: args
- ! Set up types of input and output data and the interface with C
+ ! Set up types of input and output data
type(torch_module) :: model
type(torch_tensor), dimension(1) :: in_tensor
type(torch_tensor) :: out_tensor
- real(c_wp), dimension(:,:,:,:), allocatable, target :: in_data
- integer(c_int), parameter :: n_inputs = 1
- real(c_wp), dimension(:,:), allocatable, target :: out_data
+ real(wp), dimension(:,:,:,:), allocatable, target :: in_data
+ real(wp), dimension(:,:), allocatable, target :: out_data
+ integer, parameter :: n_inputs = 1
- integer(c_int), parameter :: in_dims = 4
- integer(c_int64_t) :: in_shape(in_dims) = [1, 3, 224, 224]
- integer(c_int) :: in_layout(in_dims) = [1,2,3,4]
- integer(c_int), parameter :: out_dims = 2
- integer(c_int64_t) :: out_shape(out_dims) = [1, 1000]
- integer(c_int) :: out_layout(out_dims) = [1,2]
+ integer, parameter :: in_dims = 4
+ integer :: in_shape(in_dims) = [1, 3, 224, 224]
+ integer :: in_layout(in_dims) = [1,2,3,4]
+ integer, parameter :: out_dims = 2
+ integer :: out_shape(out_dims) = [1, 1000]
+ integer :: out_layout(out_dims) = [1,2]
! Binary file containing input tensor
character(len=*), parameter :: filename = '../data/image_tensor.dat'
@@ -72,8 +66,9 @@ subroutine main()
call load_data(filename, tensor_length, in_data)
! Create input/output tensors from the above arrays
- in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), in_dims, in_shape, torch_wp, torch_kCPU, in_layout)
- out_tensor = torch_tensor_from_blob(c_loc(out_data), out_dims, out_shape, torch_wp, torch_kCPU, out_layout)
+ in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCPU)
+
+ out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU)
! Load ML model (edit this line to use different models)
model = torch_module_load(args(1))
@@ -113,9 +108,9 @@ subroutine load_data(filename, tensor_length, in_data)
character(len=*), intent(in) :: filename
integer, intent(in) :: tensor_length
- real(c_wp), dimension(:,:,:,:), intent(out) :: in_data
+ real(wp), dimension(:,:,:,:), intent(out) :: in_data
- real(c_wp) :: flat_data(tensor_length)
+ real(wp) :: flat_data(tensor_length)
integer :: ios
character(len=100) :: ioerrmsg
@@ -166,7 +161,7 @@ subroutine calc_probs(out_data, probabilities)
implicit none
- real(c_wp), dimension(:,:), intent(in) :: out_data
+ real(wp), dimension(:,:), intent(in) :: out_data
real(wp), dimension(:,:), intent(out) :: probabilities
real(wp) :: prob_sum
diff --git a/examples/n_c_and_cpp/resnet18.py b/examples/n_c_and_cpp/resnet18.py
index 06a0c9cd..4b693d48 100644
--- a/examples/n_c_and_cpp/resnet18.py
+++ b/examples/n_c_and_cpp/resnet18.py
@@ -14,7 +14,6 @@ def initialize():
-------
model : torch.nn.Module
"""
-
# Load a pre-trained PyTorch model
print("Loading pre-trained ResNet-18 model...", end="")
model = torchvision.models.resnet18(pretrained=True)
@@ -35,7 +34,6 @@ def run_model(model):
----------
model : torch.nn.Module
"""
-
print("Running ResNet-18 model for ones...", end="")
dummy_input = torch.ones(1, 3, 224, 224)
output = model(dummy_input)
diff --git a/examples/n_c_and_cpp/resnet_infer_python.py b/examples/n_c_and_cpp/resnet_infer_python.py
index d5b462a3..a78fecca 100644
--- a/examples/n_c_and_cpp/resnet_infer_python.py
+++ b/examples/n_c_and_cpp/resnet_infer_python.py
@@ -22,7 +22,6 @@ def deploy(saved_model, device, batch_size=1):
output : torch.Tensor
result of running inference on model with Tensor of ones
"""
-
image_filename = "data/dog.jpg"
input_image = Image.open(image_filename)
preprocess = torchvision.transforms.Compose(
diff --git a/examples/tensor_tests/CMakeLists.txt b/examples/tensor_tests/CMakeLists.txt
deleted file mode 100644
index 6571f1cb..00000000
--- a/examples/tensor_tests/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
-#policy CMP0076 - target_sources source files are relative to file where target_sources is run
-cmake_policy (SET CMP0076 NEW)
-
-set(PROJECT_NAME test_tensor)
-
-project(${PROJECT_NAME} LANGUAGES Fortran)
-
-# Build in Debug mode if not specified
-if(NOT CMAKE_BUILD_TYPE)
- set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
-endif()
-
-find_package(FTorch)
-message(STATUS "Building with Fortran PyTorch coupling")
-
-# Some tests for tensor generation.
-add_executable(test_tensor test_tensor.f90)
-target_link_libraries(test_tensor PRIVATE FTorch::ftorch)
diff --git a/examples/tensor_tests/test_tensor.f90 b/examples/tensor_tests/test_tensor.f90
deleted file mode 100644
index d7bc1410..00000000
--- a/examples/tensor_tests/test_tensor.f90
+++ /dev/null
@@ -1,69 +0,0 @@
-program test_tensor
- use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc
- use ftorch
- implicit none
-
- real(kind=8), dimension(:,:), allocatable, target :: uuu_flattened, vvv_flattened
- real(kind=8), dimension(:,:), allocatable, target :: lat_reshaped, psfc_reshaped
- real(kind=8), dimension(:,:), allocatable, target :: gwfcng_x_flattened, gwfcng_y_flattened
- type(torch_tensor), target :: output_tensor
- integer(c_int), parameter :: dims_1D = 2
- integer(c_int), parameter :: dims_2D = 2
- integer(c_int64_t) :: shape_2D_F(dims_2D), shape_2D_C(dims_2D)
- integer(c_int64_t) :: shape_1D_F(dims_1D), shape_1D_C(dims_1D)
- integer(c_int) :: layout_F(dims_1D), layout_C(dims_1D)
- integer :: imax, jmax, kmax, i, j, k
-
- imax = 1
- jmax = 5
- kmax = 7
-
- shape_2D_F = (/ kmax, imax*jmax /)
- shape_1D_F = (/ 1, imax*jmax /)
- shape_2D_C = (/ imax*jmax, kmax /)
- shape_1D_C = (/ imax*jmax, 1 /)
-
- layout_F = (/ 1, 2 /)
- layout_C = (/ 2, 1 /)
-
- allocate( lat_reshaped(imax*jmax, 1) )
- allocate( uuu_flattened(imax*jmax, kmax) )
- do i = 1, imax*jmax
- lat_reshaped(i, 1) = i
- do k = 1, kmax
- uuu_flattened(i, k) = i + k*100
- end do
- end do
-
- write(*,*) uuu_flattened
-
- output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), &
- dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU, layout_F)
-
- call torch_tensor_print(output_tensor)
-
- output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), &
- dims_2D, shape_2D_F, torch_kFloat64, torch_kCPU, layout_C)
-
- call torch_tensor_print(output_tensor)
-
- shape_2D_F = shape(uuu_flattened)
- output_tensor = torch_tensor_from_array_c_double(uuu_flattened, shape_2D_F, torch_kCPU)
-
- call torch_tensor_print(output_tensor)
-
- output_tensor = torch_tensor_from_array(uuu_flattened, shape_2D_F, torch_kCPU)
-
- call torch_tensor_print(output_tensor)
-
- ! output_tensor = torch_tensor_zeros( &
- ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU)
-
- ! call torch_tensor_print(output_tensor)
-
- ! output_tensor = torch_tensor_ones( &
- ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU)
-
- ! call torch_tensor_print(output_tensor)
-
-end program test_tensor
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index fd12b3fd..b2018fe0 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -12,6 +12,16 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
+include(CheckLanguage)
+if(ENABLE_CUDA)
+ check_language(CUDA)
+ if(CMAKE_CUDA_COMPILER)
+ enable_language(CUDA)
+ else()
+ message(WARNING "No CUDA support")
+ endif()
+endif()
+
# Set RPATH behaviour
set(CMAKE_SKIP_RPATH FALSE)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
@@ -36,12 +46,15 @@ find_package(Torch REQUIRED)
# Library with C and Fortran bindings
add_library(${LIB_NAME} SHARED ctorch.cpp ftorch.f90)
+# Add an alias FTorch::ftorch for the library
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES
PUBLIC_HEADER "ctorch.h"
Fortran_MODULE_DIRECTORY "${CMAKE_BINARY_DIR}/modules"
)
+# Link TorchScript
target_link_libraries(${LIB_NAME} PRIVATE ${TORCH_LIBRARIES})
+# Include the Fortran mod files in the library
target_include_directories(${LIB_NAME}
PUBLIC
$
@@ -61,7 +74,7 @@ install(TARGETS "${LIB_NAME}"
install(EXPORT ${PROJECT_NAME}
FILE ${PROJECT_NAME}Config.cmake
NAMESPACE ${PROJECT_NAME}::
- DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake
+ DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}
)
# Install Fortran module files
diff --git a/src/ctorch.cpp b/src/ctorch.cpp
index c5578e82..58bca7f8 100644
--- a/src/ctorch.cpp
+++ b/src/ctorch.cpp
@@ -52,7 +52,7 @@ torch_tensor_t torch_zeros(int ndim, const int64_t* shape, torch_data_t dtype,
c10::IntArrayRef vshape(shape, ndim);
tensor = new torch::Tensor;
*tensor = torch::zeros(
- vshape, torch::dtype(get_dtype(dtype)).device(get_device(device)));
+ vshape, torch::dtype(get_dtype(dtype))).to(get_device(device));
} catch (const torch::Error& e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -74,7 +74,7 @@ torch_tensor_t torch_ones(int ndim, const int64_t* shape, torch_data_t dtype,
c10::IntArrayRef vshape(shape, ndim);
tensor = new torch::Tensor;
*tensor = torch::ones(
- vshape, torch::dtype(get_dtype(dtype)).device(get_device(device)));
+ vshape, torch::dtype(get_dtype(dtype))).to(get_device(device));
} catch (const torch::Error& e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -96,7 +96,7 @@ torch_tensor_t torch_empty(int ndim, const int64_t* shape, torch_data_t dtype,
c10::IntArrayRef vshape(shape, ndim);
tensor = new torch::Tensor;
*tensor = torch::empty(
- vshape, torch::dtype(get_dtype(dtype)).device(get_device(device)));
+ vshape, torch::dtype(get_dtype(dtype))).to(get_device(device));
} catch (const torch::Error& e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -122,7 +122,7 @@ torch_tensor_t torch_from_blob(void* data, int ndim, const int64_t* shape,
tensor = new torch::Tensor;
*tensor = torch::from_blob(
data, vshape,
- torch::dtype(get_dtype(dtype)).device(get_device(device)));
+ torch::dtype(get_dtype(dtype))).to(get_device(device));
} catch (const torch::Error& e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -150,7 +150,7 @@ torch_tensor_t torch_from_blob(void* data, int ndim, const int64_t* shape,
tensor = new torch::Tensor;
*tensor = torch::from_blob(
data, vshape, vstrides,
- torch::dtype(get_dtype(dtype)).device(get_device(device)));
+ torch::dtype(get_dtype(dtype))).to(get_device(device));
} catch (const torch::Error& e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
@@ -230,9 +230,6 @@ void torch_jit_module_forward(const torch_jit_script_module_t module,
std::cerr << "[ERROR]: " << e.what() << std::endl;
exit(EXIT_FAILURE);
}
- // FIXME: this should be the responsibility of the user
- if (out->is_cuda())
- torch::cuda::synchronize();
}
void torch_jit_module_delete(torch_jit_script_module_t module)
diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 340d30ac..945673cf 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -1,311 +1,1007 @@
-module ftorch
- !! The ftorch module containing wrappers to access libtorch
-
- use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
- c_float, c_double, c_char, c_ptr, c_null_ptr
- implicit none
-
- !> Type for holding a torch neural net (nn.Module).
- type torch_module
- type(c_ptr) :: p = c_null_ptr !! pointer to the neural net module in memory
- end type torch_module
-
- !> Type for holding a torch tensor.
- type torch_tensor
- type(c_ptr) :: p = c_null_ptr !! pointer to the tensor in memory
- end type torch_tensor
-
- ! From c_torch.h (torch_data_t)
- enum, bind(c)
- enumerator :: torch_kUInt8 = 0
- enumerator :: torch_kInt8 = 1
- enumerator :: torch_kInt16 = 2
- enumerator :: torch_kInt32 = 3
- enumerator :: torch_kInt64 = 4
- enumerator :: torch_kFloat16 = 5
- enumerator :: torch_kFloat32 = 6
- enumerator :: torch_kFloat64 = 7
- end enum
-
- ! From c_torch.h (torch_device_t)
- enum, bind(c)
- enumerator :: torch_kCPU = 0
- enumerator :: torch_kCUDA = 1
- end enum
-
- ! Interface for calculating tensor from array for different possible input types
- interface torch_tensor_from_array
- module procedure torch_tensor_from_array_c_float
- module procedure torch_tensor_from_array_c_double
- ! module procedure torch_tensor_from_array_c_int8_t
- ! module procedure torch_tensor_from_array_c_int16_t
- ! module procedure torch_tensor_from_array_c_int32_t
- ! module procedure torch_tensor_from_array_c_int64_t
- end interface
+!| Main module for FTorch containing types and procedures.
+! Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html).
+!
+! * License
+! FTorch is released under an MIT license.
+! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
+! file for details.
-contains
+module ftorch
- ! Torch Tensor API
- !> Exposes the given data as a tensor without taking ownership of the original data.
- !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
- function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
+ use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
+ c_float, c_double, c_char, c_ptr, c_null_ptr
+ use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
+
+ implicit none
+
+ !> Type for holding a torch neural net (nn.Module).
+ type torch_module
+ type(c_ptr) :: p = c_null_ptr !! pointer to the neural net module in memory
+ end type torch_module
+
+ !> Type for holding a Torch tensor.
+ type torch_tensor
+ type(c_ptr) :: p = c_null_ptr !! pointer to the tensor in memory
+ end type torch_tensor
+
+ !| Enumerator for Torch data types
+ ! From c_torch.h (torch_data_t)
+ ! Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran
+ enum, bind(c)
+ enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
+ enumerator :: torch_kInt8 = 1
+ enumerator :: torch_kInt16 = 2
+ enumerator :: torch_kInt32 = 3
+ enumerator :: torch_kInt64 = 4
+ enumerator :: torch_kFloat16 = 5 ! not supported in Fortran
+ enumerator :: torch_kFloat32 = 6
+ enumerator :: torch_kFloat64 = 7
+ end enum
+
+
+ !| Enumerator for Torch devices
+ ! From c_torch.h (torch_device_t)
+ enum, bind(c)
+ enumerator :: torch_kCPU = 0
+ enumerator :: torch_kCUDA = 1
+ end enum
+
+ !> Interface for directing `torch_tensor_from_array` to possible input types and ranks
+ interface torch_tensor_from_array
+ module procedure torch_tensor_from_array_int8_1d
+ module procedure torch_tensor_from_array_int8_2d
+ module procedure torch_tensor_from_array_int8_3d
+ module procedure torch_tensor_from_array_int8_4d
+ module procedure torch_tensor_from_array_int16_1d
+ module procedure torch_tensor_from_array_int16_2d
+ module procedure torch_tensor_from_array_int16_3d
+ module procedure torch_tensor_from_array_int16_4d
+ module procedure torch_tensor_from_array_int32_1d
+ module procedure torch_tensor_from_array_int32_2d
+ module procedure torch_tensor_from_array_int32_3d
+ module procedure torch_tensor_from_array_int32_4d
+ module procedure torch_tensor_from_array_int64_1d
+ module procedure torch_tensor_from_array_int64_2d
+ module procedure torch_tensor_from_array_int64_3d
+ module procedure torch_tensor_from_array_int64_4d
+ module procedure torch_tensor_from_array_real32_1d
+ module procedure torch_tensor_from_array_real32_2d
+ module procedure torch_tensor_from_array_real32_3d
+ module procedure torch_tensor_from_array_real32_4d
+ module procedure torch_tensor_from_array_real64_1d
+ module procedure torch_tensor_from_array_real64_2d
+ module procedure torch_tensor_from_array_real64_3d
+ module procedure torch_tensor_from_array_real64_4d
+ end interface
+
+ interface
+ function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
+ bind(c, name = 'torch_from_blob')
use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
! Arguments
- type(c_ptr), intent(in) :: data !! Pointer to data
- integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
- integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
- integer(c_int), intent(in) :: dtype !! Data type of the tensor
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data
- type(torch_tensor) :: tensor !! Returned tensor
-
- integer(c_int) :: i !! loop index
- integer(c_int64_t) :: strides(ndims) !! Strides for accessing data
-
- interface
- function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
- bind(c, name = 'torch_from_blob')
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
- type(c_ptr), value, intent(in) :: data
- integer(c_int), value, intent(in) :: ndims
- integer(c_int64_t), intent(in) :: tensor_shape(*)
- integer(c_int64_t), intent(in) :: strides(*)
- integer(c_int), value, intent(in) :: dtype
- integer(c_int), value, intent(in) :: device
- type(c_ptr) :: tensor
- end function torch_from_blob_c
- end interface
-
- strides(layout(1)) = 1
- do i = 2, ndims
- strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
- end do
- tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
- end function torch_tensor_from_blob
-
- !> This routine will take an (i, j, k) array and return an (k, j, i) tensor
- !> it is invoked from a set of interfaces `torch_tensor_from_array_dtype`
- function t_t_from_array(data_arr, tensor_shape, dtype, device) result(tensor)
-
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc
+ type(c_ptr), value, intent(in) :: data
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int64_t), intent(in) :: strides(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor_p
+ end function torch_from_blob_c
+ end interface
+
+contains
+
+ ! Torch Tensor API
+ !| Exposes the given data as a tensor without taking ownership of the original data.
+ ! This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+ function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ type(c_ptr), intent(in) :: data !! Pointer to data
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ integer(c_int) :: i !! loop index
+ integer(c_int64_t) :: strides(ndims) !! Strides for accessing data
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
+ end do
+ tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
+ end function torch_tensor_from_blob
+
+ !> Returns a tensor filled with the scalar value 1.
+ function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ interface
+ function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
+ bind(c, name = 'torch_ones')
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor
+ end function torch_ones_c
+ end interface
+
+ tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
+ end function torch_tensor_ones
+
+ !> Returns a tensor filled with the scalar value 0.
+ function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ interface
+ function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
+ bind(c, name = 'torch_zeros')
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor
+ end function torch_zeros_c
+ end interface
+
+ tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
+ end function torch_tensor_zeros
+
+ !> Prints the contents of a tensor.
+ subroutine torch_tensor_print(tensor)
+ type(torch_tensor), intent(in) :: tensor !! Input tensor
+
+ interface
+ subroutine torch_tensor_print_c(tensor) &
+ bind(c, name = 'torch_tensor_print')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: tensor
+ end subroutine torch_tensor_print_c
+ end interface
+
+ call torch_tensor_print_c(tensor%p)
+ end subroutine torch_tensor_print
+
+ !> Deallocates a tensor.
+ subroutine torch_tensor_delete(tensor)
+ type(torch_tensor), intent(in) :: tensor !! Input tensor
+
+ interface
+ subroutine torch_tensor_delete_c(tensor) &
+ bind(c, name = 'torch_tensor_delete')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: tensor
+ end subroutine torch_tensor_delete_c
+ end interface
+
+ call torch_tensor_delete_c(tensor%p)
+ end subroutine torch_tensor_delete
+
+ ! Torch Module API
+ !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
+ function torch_module_load(filename) result(module)
+ use, intrinsic :: iso_c_binding, only : c_null_char
+ character(*), intent(in) :: filename !! Filename of Torch Script module
+ type(torch_module) :: module !! Returned deserialized module
+
+ interface
+ function torch_jit_load_c(filename) result(module) &
+ bind(c, name = 'torch_jit_load')
+ use, intrinsic :: iso_c_binding, only : c_char, c_ptr
+ character(c_char), intent(in) :: filename(*)
+ type(c_ptr) :: module
+ end function torch_jit_load_c
+ end interface
+
+ ! Need to append c_null_char at end of filename
+ module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
+ end function torch_module_load
+
+ !> Performs a forward pass of the module with the input tensors
+ subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+ use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+ type(torch_module), intent(in) :: module !! Module
+ type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors
+ type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
+ integer(c_int) :: n_inputs
+
+ integer :: i
+ type(c_ptr), dimension(n_inputs), target :: input_ptrs
+
+ interface
+ subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+ output_tensor) &
+ bind(c, name = 'torch_jit_module_forward')
+ use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+ type(c_ptr), value, intent(in) :: module
+ type(c_ptr), value, intent(in) :: input_tensors
+ integer(c_int), value, intent(in) :: n_inputs
+ type(c_ptr), value, intent(in) :: output_tensor
+ end subroutine torch_jit_module_forward_c
+ end interface
+
+ ! Assign array of pointers to the input tensors
+ do i = 1, n_inputs
+ input_ptrs(i) = input_tensors(i)%p
+ end do
+
+ call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+ end subroutine torch_module_forward
+
+ !> Deallocates a Torch Script module
+ subroutine torch_module_delete(module)
+ type(torch_module), intent(in) :: module !! Module to deallocate
+
+ interface
+ subroutine torch_jit_module_delete_c(module) &
+ bind(c, name = 'torch_jit_module_delete')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: module
+ end subroutine torch_jit_module_delete_c
+ end interface
+
+ call torch_jit_module_delete_c(module%p)
+ end subroutine torch_module_delete
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int8`
+ function torch_tensor_from_array_int8_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int8
+
+ ! inputs
+ integer(kind=int8), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int8_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int8`
+ function torch_tensor_from_array_int8_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int8
+
+ ! inputs
+ integer(kind=int8), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int8_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int8`
+ function torch_tensor_from_array_int8_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int8
+
+ ! inputs
+ integer(kind=int8), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int8_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int8`
+ function torch_tensor_from_array_int8_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int8
+
+ ! inputs
+ integer(kind=int8), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int8_4d
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int16`
+ function torch_tensor_from_array_int16_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int16
+
+ ! inputs
+ integer(kind=int16), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int16_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int16`
+ function torch_tensor_from_array_int16_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int16
+
+ ! inputs
+ integer(kind=int16), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int16_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int16`
+ function torch_tensor_from_array_int16_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int16
+
+ ! inputs
+ integer(kind=int16), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int16_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int16`
+ function torch_tensor_from_array_int16_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int16
+
+ ! inputs
+ integer(kind=int16), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int16_4d
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int32`
+ function torch_tensor_from_array_int32_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int32
+
+ ! inputs
+ integer(kind=int32), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int32_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int32`
+ function torch_tensor_from_array_int32_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int32
+
+ ! inputs
+ integer(kind=int32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int32_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int32`
+ function torch_tensor_from_array_int32_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int32
+
+ ! inputs
+ integer(kind=int32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int32_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int32`
+ function torch_tensor_from_array_int32_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int32
+
+ ! inputs
+ integer(kind=int32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int32_4d
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int64`
+ function torch_tensor_from_array_int64_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int64
+
+ ! inputs
+ integer(kind=int64), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int64_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int64`
+ function torch_tensor_from_array_int64_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int64
+
+ ! inputs
+ integer(kind=int64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int64_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int64`
+ function torch_tensor_from_array_int64_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int64
+
+ ! inputs
+ integer(kind=int64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int64_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int64`
+ function torch_tensor_from_array_int64_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : int64
+
+ ! inputs
+ integer(kind=int64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_int64_4d
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real32`
+ function torch_tensor_from_array_real32_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real32
+
+ ! inputs
+ real(kind=real32), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real32_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real32`
+ function torch_tensor_from_array_real32_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real32
+
+ ! inputs
+ real(kind=real32), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real32_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real32`
+ function torch_tensor_from_array_real32_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real32
+
+ ! inputs
+ real(kind=real32), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real32_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real32`
+ function torch_tensor_from_array_real32_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real32
+
+ ! inputs
+ real(kind=real32), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real32_4d
+
+ !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real64`
+ function torch_tensor_from_array_real64_1d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real64
+
+ ! inputs
+ real(kind=real64), intent(in), target :: data_in(:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(1) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(1) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+ integer(c_int64_t) :: strides(1) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 1 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real64_1d
+
+ !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real64`
+ function torch_tensor_from_array_real64_2d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real64
+
+ ! inputs
+ real(kind=real64), intent(in), target :: data_in(:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(2) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(2) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+ integer(c_int64_t) :: strides(2) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 2 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real64_2d
+
+ !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real64`
+ function torch_tensor_from_array_real64_3d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real64
+
+ ! inputs
+ real(kind=real64), intent(in), target :: data_in(:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(3) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(3) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+ integer(c_int64_t) :: strides(3) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 3 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real64_3d
+
+ !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real64`
+ function torch_tensor_from_array_real64_4d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : real64
+
+ ! inputs
+ real(kind=real64), intent(in), target :: data_in(:,:,:,:) !! Input data that tensor will point at
+ integer, intent(in) :: layout(4) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(4) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+ integer(c_int64_t) :: strides(4) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = 4 !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_real64_4d
- ! Arguments
- type(c_ptr), intent(in) :: data_arr !! Pointer to data
- integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor
- integer(c_int), intent(in) :: dtype !! Data type of the tensor
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- type(torch_tensor) :: tensor !! Returned tensor
-
- integer(c_int) :: i !! loop index
- integer(c_int64_t), allocatable :: strides(:) !! Strides for accessing data
- integer(c_int), allocatable :: layout(:) !! Layout for strides for accessing data
- integer(c_int) :: ndims !! Number of dimensions of the tensor
-
- interface
- function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
- bind(c, name = 'torch_from_blob')
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
- type(c_ptr), value, intent(in) :: data
- integer(c_int), value, intent(in) :: ndims
- integer(c_int64_t), intent(in) :: tensor_shape(*)
- integer(c_int64_t), intent(in) :: strides(*)
- integer(c_int), value, intent(in) :: dtype
- integer(c_int), value, intent(in) :: device
- type(c_ptr) :: tensor
- end function torch_from_blob_c
- end interface
-
- ndims = size(tensor_shape)
-
- allocate(strides(ndims))
- allocate(layout(ndims))
-
- ! Fortran Layout
- do i = 1, ndims
- layout(i) = i
- end do
-
- strides(layout(1)) = 1
- do i = 2, ndims
- strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
- end do
-
- tensor%p = torch_from_blob_c(data_arr, ndims, tensor_shape, strides, dtype, device)
-
- deallocate(strides)
- deallocate(layout)
-
- end function t_t_from_array
-
- !> Returns a tensor filled with the scalar value 1.
- function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
- integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
- integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
- integer(c_int), intent(in) :: dtype !! Data type of the tensor
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- type(torch_tensor) :: tensor !! Returned tensor
-
- interface
- function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
- bind(c, name = 'torch_ones')
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
- integer(c_int), value, intent(in) :: ndims
- integer(c_int64_t), intent(in) :: tensor_shape(*)
- integer(c_int), value, intent(in) :: dtype
- integer(c_int), value, intent(in) :: device
- type(c_ptr) :: tensor
- end function torch_ones_c
- end interface
-
- tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
- end function torch_tensor_ones
-
- !> Returns a tensor filled with the scalar value 0.
- function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
- integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
- integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
- integer(c_int), intent(in) :: dtype !! Data type of the tensor
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- type(torch_tensor) :: tensor !! Returned tensor
-
- interface
- function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
- bind(c, name = 'torch_zeros')
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
- integer(c_int), value, intent(in) :: ndims
- integer(c_int64_t), intent(in) :: tensor_shape(*)
- integer(c_int), value, intent(in) :: dtype
- integer(c_int), value, intent(in) :: device
- type(c_ptr) :: tensor
- end function torch_zeros_c
- end interface
-
- tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
- end function torch_tensor_zeros
-
- !> Prints the contents of a tensor.
- subroutine torch_tensor_print(tensor)
- type(torch_tensor), intent(in) :: tensor !! Input tensor
-
- interface
- subroutine torch_tensor_print_c(tensor) &
- bind(c, name = 'torch_tensor_print')
- use, intrinsic :: iso_c_binding, only : c_ptr
- type(c_ptr), value, intent(in) :: tensor
- end subroutine torch_tensor_print_c
- end interface
-
- call torch_tensor_print_c(tensor%p)
- end subroutine torch_tensor_print
-
- !> Deallocates a tensor.
- subroutine torch_tensor_delete(tensor)
- type(torch_tensor), intent(in) :: tensor !! Input tensor
-
- interface
- subroutine torch_tensor_delete_c(tensor) &
- bind(c, name = 'torch_tensor_delete')
- use, intrinsic :: iso_c_binding, only : c_ptr
- type(c_ptr), value, intent(in) :: tensor
- end subroutine torch_tensor_delete_c
- end interface
-
- call torch_tensor_delete_c(tensor%p)
- end subroutine torch_tensor_delete
-
- ! Torch Module API
- !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
- function torch_module_load(filename) result(module)
- use, intrinsic :: iso_c_binding, only : c_null_char
- character(*), intent(in) :: filename !! Filename of Torch Script module
- type(torch_module) :: module !! Returned deserialized module
-
- interface
- function torch_jit_load_c(filename) result(module) &
- bind(c, name = 'torch_jit_load')
- use, intrinsic :: iso_c_binding, only : c_char, c_ptr
- character(c_char), intent(in) :: filename(*)
- type(c_ptr) :: module
- end function torch_jit_load_c
- end interface
-
- ! Need to append c_null_char at end of filename
- module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
- end function torch_module_load
-
- !> Performs a forward pass of the module with the input tensors
- subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
- use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
- type(torch_module), intent(in) :: module !! Module
- type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors
- type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
- integer(c_int) :: n_inputs !! Number of tensors in `input_tensors`
-
- integer :: i
- type(c_ptr), dimension(n_inputs), target :: input_ptrs
-
- interface
- subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
- output_tensor) &
- bind(c, name = 'torch_jit_module_forward')
- use, intrinsic :: iso_c_binding, only : c_ptr, c_int
- type(c_ptr), value, intent(in) :: module
- type(c_ptr), value, intent(in) :: input_tensors
- integer(c_int), value, intent(in) :: n_inputs
- type(c_ptr), value, intent(in) :: output_tensor
- end subroutine torch_jit_module_forward_c
- end interface
-
- ! Assign array of pointers to the input tensors
- do i = 1, n_inputs
- input_ptrs(i) = input_tensors(i)%p
- end do
-
- call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
- end subroutine torch_module_forward
-
- !> Deallocates a Torch Script module
- subroutine torch_module_delete(module)
- type(torch_module), intent(in) :: module !! Module
-
- interface
- subroutine torch_jit_module_delete_c(module) &
- bind(c, name = 'torch_jit_module_delete')
- use, intrinsic :: iso_c_binding, only : c_ptr
- type(c_ptr), value, intent(in) :: module
- end subroutine torch_jit_module_delete_c
- end interface
-
- call torch_jit_module_delete_c(module%p)
- end subroutine torch_module_delete
-
- ! Series of interface functions
- function torch_tensor_from_array_c_double(data_arr, tensor_shape, device) result(tensor)
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc
- real(c_double), intent(in), target :: data_arr(*) !! Fortran array of data
- ! real(c_double), intent(in), target :: data_arr(*) !! Fortran array of data
- integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor
- integer(c_int), parameter :: dtype = torch_kFloat64
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- type(torch_tensor) :: tensor !! Returned tensor
-
- tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
-
- end function torch_tensor_from_array_c_double
-
- function torch_tensor_from_array_c_float(data_arr, tensor_shape, device) result(tensor)
- use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
- real(c_float), intent(in), target :: data_arr(*) !! Fortran array of data
- integer(c_int64_t), intent(in) :: tensor_shape(:) !! Shape of the tensor
- integer(c_int), parameter :: dtype = torch_kFloat32
- integer(c_int), intent(in) :: device !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
- type(torch_tensor) :: tensor !! Returned tensor
-
- tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
-
- end function torch_tensor_from_array_c_float
end module ftorch
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
new file mode 100644
index 00000000..16371e48
--- /dev/null
+++ b/src/ftorch.fypp
@@ -0,0 +1,296 @@
+#:def ranksuffix(RANK)
+$:'' if RANK == 0 else '(' + ':' + ',:' * (RANK - 1) + ')'
+#:enddef ranksuffix
+#:set PRECISIONS = ['int8', 'int16', 'int32', 'int64', 'real32', 'real64']
+#:set C_PRECISIONS = ['c_int8_t', 'c_int16_t', 'c_int32_t', 'c_int64_t', 'c_float', 'c_double']
+#:set C_PRECISIONS = dict(zip(PRECISIONS, C_PRECISIONS))
+#:set ENUMS = dict(zip(PRECISIONS, ['torch_kInt8', 'torch_kInt16', 'torch_kInt32', 'torch_kInt64', 'torch_kFloat32', 'torch_kFloat64']))
+#:set RANKS = range(1, 5)
+#:def enum_from_prec(PRECISION)
+$:ENUMS[PRECISION]
+#:enddef enum_from_prec
+#:def c_prec(PRECISION)
+$:C_PRECISIONS[PRECISION]
+#:enddef c_prec
+#:def f_type(PRECISION)
+$:'integer' if PRECISION[:3] == 'int' else 'real'
+#:enddef f_type
+!| Main module for FTorch containing types and procedures.
+! Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html).
+!
+! * License
+! FTorch is released under an MIT license.
+! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
+! file for details.
+
+module ftorch
+
+ use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
+ c_float, c_double, c_char, c_ptr, c_null_ptr
+ use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
+
+ implicit none
+
+ !> Type for holding a torch neural net (nn.Module).
+ type torch_module
+ type(c_ptr) :: p = c_null_ptr !! pointer to the neural net module in memory
+ end type torch_module
+
+ !> Type for holding a Torch tensor.
+ type torch_tensor
+ type(c_ptr) :: p = c_null_ptr !! pointer to the tensor in memory
+ end type torch_tensor
+
+ !| Enumerator for Torch data types
+ ! From c_torch.h (torch_data_t)
+ ! Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran
+ enum, bind(c)
+ enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
+ enumerator :: torch_kInt8 = 1
+ enumerator :: torch_kInt16 = 2
+ enumerator :: torch_kInt32 = 3
+ enumerator :: torch_kInt64 = 4
+ enumerator :: torch_kFloat16 = 5 ! not supported in Fortran
+ enumerator :: torch_kFloat32 = 6
+ enumerator :: torch_kFloat64 = 7
+ end enum
+
+
+ !| Enumerator for Torch devices
+ ! From c_torch.h (torch_device_t)
+ enum, bind(c)
+ enumerator :: torch_kCPU = 0
+ enumerator :: torch_kCUDA = 1
+ end enum
+
+ !> Interface for directing `torch_tensor_from_array` to possible input types and ranks
+ interface torch_tensor_from_array
+ #:for PREC in PRECISIONS
+ #:for RANK in RANKS
+ module procedure torch_tensor_from_array_${PREC}$_${RANK}$d
+ #:endfor
+ #:endfor
+ end interface
+
+ interface
+ function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
+ bind(c, name = 'torch_from_blob')
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+
+ ! Arguments
+ type(c_ptr), value, intent(in) :: data
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int64_t), intent(in) :: strides(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor_p
+ end function torch_from_blob_c
+ end interface
+
+contains
+
+ ! Torch Tensor API
+ !| Exposes the given data as a tensor without taking ownership of the original data.
+ ! This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+ function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ type(c_ptr), intent(in) :: data !! Pointer to data
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ integer(c_int), intent(in) :: layout(*) !! Layout for strides for accessing data
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ integer(c_int) :: i !! loop index
+ integer(c_int64_t) :: strides(ndims) !! Strides for accessing data
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
+ end do
+ tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
+ end function torch_tensor_from_blob
+
+ !> Returns a tensor filled with the scalar value 1.
+ function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ interface
+ function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
+ bind(c, name = 'torch_ones')
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor
+ end function torch_ones_c
+ end interface
+
+ tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
+ end function torch_tensor_ones
+
+ !> Returns a tensor filled with the scalar value 0.
+ function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+ integer(c_int), intent(in) :: ndims !! Number of dimensions of the tensor
+ integer(c_int64_t), intent(in) :: tensor_shape(*) !! Shape of the tensor
+ integer(c_int), intent(in) :: dtype !! Data type of the tensor
+ integer(c_int), intent(in) :: device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ interface
+ function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
+ bind(c, name = 'torch_zeros')
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+ integer(c_int), value, intent(in) :: ndims
+ integer(c_int64_t), intent(in) :: tensor_shape(*)
+ integer(c_int), value, intent(in) :: dtype
+ integer(c_int), value, intent(in) :: device
+ type(c_ptr) :: tensor
+ end function torch_zeros_c
+ end interface
+
+ tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
+ end function torch_tensor_zeros
+
+ !> Prints the contents of a tensor.
+ subroutine torch_tensor_print(tensor)
+ type(torch_tensor), intent(in) :: tensor !! Input tensor
+
+ interface
+ subroutine torch_tensor_print_c(tensor) &
+ bind(c, name = 'torch_tensor_print')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: tensor
+ end subroutine torch_tensor_print_c
+ end interface
+
+ call torch_tensor_print_c(tensor%p)
+ end subroutine torch_tensor_print
+
+ !> Deallocates a tensor.
+ subroutine torch_tensor_delete(tensor)
+ type(torch_tensor), intent(in) :: tensor !! Input tensor
+
+ interface
+ subroutine torch_tensor_delete_c(tensor) &
+ bind(c, name = 'torch_tensor_delete')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: tensor
+ end subroutine torch_tensor_delete_c
+ end interface
+
+ call torch_tensor_delete_c(tensor%p)
+ end subroutine torch_tensor_delete
+
+ ! Torch Module API
+ !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
+ function torch_module_load(filename) result(module)
+ use, intrinsic :: iso_c_binding, only : c_null_char
+ character(*), intent(in) :: filename !! Filename of Torch Script module
+ type(torch_module) :: module !! Returned deserialized module
+
+ interface
+ function torch_jit_load_c(filename) result(module) &
+ bind(c, name = 'torch_jit_load')
+ use, intrinsic :: iso_c_binding, only : c_char, c_ptr
+ character(c_char), intent(in) :: filename(*)
+ type(c_ptr) :: module
+ end function torch_jit_load_c
+ end interface
+
+ ! Need to append c_null_char at end of filename
+ module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
+ end function torch_module_load
+
+ !> Performs a forward pass of the module with the input tensors
+ subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+ use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+ type(torch_module), intent(in) :: module !! Module
+ type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors
+ type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
+ integer(c_int) :: n_inputs
+
+ integer :: i
+ type(c_ptr), dimension(n_inputs), target :: input_ptrs
+
+ interface
+ subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+ output_tensor) &
+ bind(c, name = 'torch_jit_module_forward')
+ use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+ type(c_ptr), value, intent(in) :: module
+ type(c_ptr), value, intent(in) :: input_tensors
+ integer(c_int), value, intent(in) :: n_inputs
+ type(c_ptr), value, intent(in) :: output_tensor
+ end subroutine torch_jit_module_forward_c
+ end interface
+
+ ! Assign array of pointers to the input tensors
+ do i = 1, n_inputs
+ input_ptrs(i) = input_tensors(i)%p
+ end do
+
+ call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+ end subroutine torch_module_forward
+
+ !> Deallocates a Torch Script module
+ subroutine torch_module_delete(module)
+ type(torch_module), intent(in) :: module !! Module to deallocate
+
+ interface
+ subroutine torch_jit_module_delete_c(module) &
+ bind(c, name = 'torch_jit_module_delete')
+ use, intrinsic :: iso_c_binding, only : c_ptr
+ type(c_ptr), value, intent(in) :: module
+ end subroutine torch_jit_module_delete_c
+ end interface
+
+ call torch_jit_module_delete_c(module%p)
+ end subroutine torch_module_delete
+
+ #:for PREC in PRECISIONS
+ #:for RANK in RANKS
+ !> Return a Torch tensor pointing to data_in array of rank ${RANK}$ containing data of type `${PREC}$`
+ function torch_tensor_from_array_${PREC}$_${RANK}$d(data_in, layout, c_device) result(tensor)
+ use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+ use, intrinsic :: iso_fortran_env, only : ${PREC}$
+
+ ! inputs
+ ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$ !! Input data that tensor will point at
+ integer, intent(in) :: layout(${RANK}$) !! Control order of indices
+ integer(c_int), intent(in) :: c_device !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+ ! output tensory
+ type(torch_tensor) :: tensor !! Returned tensor
+
+ ! local data
+ integer(c_int64_t) :: c_tensor_shape(${RANK}$) !! Shape of the tensor
+ integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Data type
+ integer(c_int64_t) :: strides(${RANK}$) !! Strides for accessing data
+ integer(c_int), parameter :: ndims = ${RANK}$ !! Number of dimension of input data
+ integer :: i
+
+ c_tensor_shape = shape(data_in)
+
+ strides(layout(1)) = 1
+ do i = 2, ndims
+ strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+ end do
+
+ tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+ end function torch_tensor_from_array_${PREC}$_${RANK}$d
+
+ #:endfor
+ #:endfor
+
+end module ftorch
diff --git a/utils/README.md b/utils/README.md
index 82693903..3b909e41 100644
--- a/utils/README.md
+++ b/utils/README.md
@@ -11,9 +11,9 @@ Dependencies:
- PyTorch
### Usage
-1. Create and activate a virtual environment with PyTorch and any dependencied for your model.
-2. Place the `pt2ts.py` script in the save folder as your model files.
-3. Import your model into `pt2ts.py`
-4. Run with `python3 pt2ts.py`
+1. Create and activate a virtual environment with PyTorch and any dependencies for your model.
+2. Place the `pt2ts.py` script in the same folder as your model files.
+3. Import your model into `pt2ts.py` and amend options as necessary (search for `FPTLIB-TODO`).
+4. Run with `python3 pt2ts.py`.
The model will be saved in the location from which `pt2ts.py` is running.