Skip to content

Commit

Permalink
Resolve merge conflicts with latest updates to main.
Browse files Browse the repository at this point in the history
  • Loading branch information
jatkinson1000 committed Nov 21, 2023
2 parents 865ba88 + 530ab25 commit 4f667bb
Show file tree
Hide file tree
Showing 22 changed files with 1,542 additions and 487 deletions.
69 changes: 69 additions & 0 deletions .githooks/pre-commit
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/bin/bash
#
# A hook script to verify what is about to be committed.
# Called by "git commit" with no arguments. The hook should
# exit with non-zero status after issuing an appropriate message if
# it wants to stop the commit.

# Fail immediately at first issue with the relevant exit status.
set -eo pipefail

# ===================================================================

if git rev-parse --verify HEAD >/dev/null 2>&1
then
against=HEAD
else
# Initial commit: diff against an empty tree object
against=$(git hash-object -t tree /dev/null)
fi

# ===================================================================

# Check that ftorch.90 is not modified and staged alone.
git diff --cached --name-only | if grep --quiet "ftorch.f90"; then
git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then
cat <<\EOF
Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed.
ftorch.90 should be generated from ftorch.fypp using fypp.
Please restore ftorch.f90 and make your modifications to ftorch.fypp instead.
EOF
exit 1
fi
fi

# Check to see if ftorch.fypp has been modified AND is staged.
git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then

# Check that ftorch.90 is also modified and staged.
git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then
cat <<\EOF
Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not.
ftorch.90 should be generated from ftorch.fypp and both committed together.
Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together.
EOF
exit 1
else
# Check fypp available, and raise error and exit if not.
if ! command -v fypp &> /dev/null; then
cat <<\EOF
echo "Error: Could not find fypp to run on ftorch.fypp.
Please install fypp using "pip install fypp" and then try committing again.
EOF
exit 1
fi

# If fypp is available and both .f90 and .fypp staged, check they match.
fypp src/ftorch.fypp src/ftorch.f90_tmp
if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then
rm src/ftorch.f90_tmp
cat <<\EOF
Error: The code in ftorch.f90 does not match that expected from ftorch.fypp.
Please re-run fypp on ftorch.fypp to ensure consistency before committing.
EOF
exit 1
else
rm src/ftorch.f90_tmp
fi
fi
fi
27 changes: 27 additions & 0 deletions .github/workflows/fypp.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: fypp-checks

on:
# run on every push
push:

jobs:
various:
name: FYPP checks - runs check on fypp and f90 files
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- run: pip install fypp

- name: Check fypp matches f90
run: |
fypp src/ftorch.fypp src/temp.f90_temp
if ! diff -q src/ftorch.f90 src/temp.f90_temp; then
echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp."
echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit."
exit 1
else
exit 0
fi
32 changes: 32 additions & 0 deletions .github/workflows/python_qc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: python-code-qc

on:
# run on every push to main
push:
branches:
- main
# run on every push (not commit) to a PR, plus open/reopen
pull_request:
types:
- synchronize
- opened
- reopened

jobs:
various:
name: Python Code QC (Black, pydocstyle)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- run: pip install black pydocstyle

# annotate each step with `if: always` to run all regardless
- name: Assert that code matches Black code style
if: always()
uses: psf/black@stable
- name: Lint with pydocstyle
if: always()
run: pydocstyle --convention=numpy ./
79 changes: 55 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,31 @@ To build and install the library:
| [`CMAKE_PREFIX_PATH`](https://cmake.org/cmake/help/latest/variable/CMAKE_PREFIX_PATH.html) | `</path/to/libTorch/>` | Location of Torch installation<sup>1</sup> |
| [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `</path/to/install/lib/at/>` | Location at which the library files should be installed. By default this is `/usr/local` |
| [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code|
| `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA<sup>2</sup> |

<sup>1</sup> _The path to the Torch installation needs to allow cmake to locate the relevant Torch cmake files.
If Torch has been [installed as libtorch](https://pytorch.org/cppdocs/installing.html)
then this should be the absolute path to the unzipped libtorch distribution.
If Torch has been installed as PyTorch in a python [venv (virtual environment)](https://docs.python.org/3/library/venv.html),
e.g. with `pip install torch`, then this should be `</path/to/venv/>lib/python<3.xx>/site-packages/torch/`._

<sup>2</sup> _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g.
`pip install torch --index-url https://download.pytorch.org/whl/cpu`
or
`pip install torch --index-url https://download.pytorch.org/whl/cu118`
(for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._


4. Make and install the code to the chosen location with:
```
make
make install
```
This will place the following directories at the install location:
* `CMAKE_INSTALL_PREFIX/include/` - contains header and mod files
* `CMAKE_INSTALL_PREFIX/lib64/` - contains `cmake` directory and `.so` files
* `CMAKE_INSTALL_PREFIX/lib/` - contains `cmake` directory and `.so` files
_Note: depending on your system and architecture `lib` may be `lib64`, and
you may have `.dll` files or similar._


## Usage
Expand All @@ -119,11 +130,10 @@ To use the trained Torch model from within Fortran we need to import the `ftorch
A very simple example is given below.
For more detailed documentation please consult the API documentation, source code, and examples.

This minimal snippet loads a saved Torch model, creates inputs consisting of two `10x10` matrices (one of ones, and one of zeros), and runs the model to infer output.
This minimal snippet loads a saved Torch model, creates an input consisting of a `10x10` matrix of ones, and runs the model to infer output.
This is illustrative only, and we recommend following the [examples](examples/) before writing your own code to explore more features.

```fortran
! Import any C bindings as required for this code
use, intrinsic :: iso_c_binding, only: c_int, c_int64_t, c_loc
! Import library for interfacing with PyTorch
use ftorch
Expand All @@ -132,34 +142,32 @@ implicit none
! Generate an object to hold the Torch model
type(torch_module) :: model
! Set up types of input and output data and the interface with C
integer(c_int), parameter :: dims_input = 2
integer(c_int64_t) :: shape_input(dims_input)
integer(c_int), parameter :: n_inputs = 2
! Set up types of input and output data
integer, parameter :: n_inputs = 1
type(torch_tensor), dimension(n_inputs) :: model_input_arr
integer(c_int), parameter :: dims_output = 1
integer(c_int64_t) :: shape_output(dims_output)
type(torch_tensor) :: model_output
! Set up the model inputs as Fortran arrays
real, dimension(10,10), target :: input_1, input_2
! Set up the model input and output as Fortran arrays
real, dimension(10,10), target :: input
real, dimension(5), target :: output
! Set up number of dimensions of input tensor and axis order
integer, parameter :: in_dims = 2
integer :: in_layout(in_dims) = [1,2]
integer, parameter :: out_dims = 1
integer :: out_layout(out_dims) = [1]
! Initialise the Torch model to be used
model = torch_module_load("/path/to/saved/model.pt")
! Initialise the inputs as Fortran
input_1 = 0.0
input_2 = 1.0
! Initialise the inputs as Fortran array of ones
input = 1.0
! Wrap Fortran data as no-copy Torch Tensors
! There may well be some reshaping required depending on the
! structure of the model which is not covered here (see examples)
shape_input = (/10, 10/)
shape_output = (/5/)
model_input_arr(1) = torch_tensor_from_blob(c_loc(input_1), dims_input, shape_input, torch_kFloat64, torch_kCPU)
model_input_arr(2) = torch_tensor_from_blob(c_loc(input_2), dims_input, shape_input, torch_kFloat64, torch_kCPU)
model_output = torch_tensor_from_blob(c_loc(output), dims_output, shape_output, torch_kFloat64, torch_kCPU)
model_input_arr(1) = torch_tensor_from_array(input, in_layout, torch_kCPU)
model_output = torch_tensor_from_array(output, out_layout, torch_kCPU)
! Run model and Infer
! Again, there may be some reshaping required depending on model design
Expand All @@ -171,7 +179,6 @@ write(*,*) output
! Clean up
call torch_module_delete(model)
call torch_tensor_delete(model_input_arr(1))
call torch_tensor_delete(model_input_arr(2))
call torch_tensor_delete(model_output)
```

Expand All @@ -189,7 +196,10 @@ find_package(FTorch)
target_link_libraries( <executable> PRIVATE FTorch::ftorch )
message(STATUS "Building with Fortran PyTorch coupling")
```
and using the `-DFTorch_DIR=</path/to/install/location>` flag when running cmake.
and using the `-DCMAKE_PREFIX_PATH=</path/to/install/location>` flag when running cmake.
_Note: If you used the `CMAKE_INSTALL_PREFIX` argument when
[building and installing the library](#library-installation) above then you should use
the same path for `</path/to/install/location>`._

#### Make
To build with make we need to include the library when compiling and link the executable
Expand All @@ -203,14 +213,35 @@ FCFLAGS += -I<path/to/install/location>/include/ftorch

When compiling the final executable add the following link flag:
```
LDFLAGS += -L<path/to/install/location>/lib64 -lftorch
LDFLAGS += -L<path/to/install/location>/lib -lftorch
```

You may also need to add the location of the `.so` files to your `LD_LIBRARY_PATH`
unless installing in a default location:
```
export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:<path/to/installation>/lib64
export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:<path/to/install/location>/lib
```
_Note: depending on your system and architecture `lib` may be `lib64` or something similar._

### 4. Running on GPUs

In order to run a model on GPU, two main changes are required:

1. When saving your TorchScript model, ensure that it is on the GPU. For example, when using [pt2ts.py](utils/pt2ts.py), this can be done by uncommenting the following lines:

```
device = torch.device('cuda')
trained_model = trained_model.to(device)
trained_model.eval()
trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)
```

Note: this also moves the dummy input tensors to the GPU. This is not necessary for saving the model, but the tensors must also be on the GPU to test that the models runs.


2. When calling `torch_tensor_from_blob` in Fortran, the device for the input tensor(s), but not the output tensor(s),
should be set to `torch_kCUDA`, rather than `torch_kCPU`. This ensures that the inputs are on the same device as the model.


## Examples
Expand Down
3 changes: 1 addition & 2 deletions examples/1_SimpleNet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
cmake .. -DFTorch_DIR=<path/to/your/installation/of/library/>lib/cmake/ -DCMAKE_BUILD_TYPE=Release
cmake .. -DCMAKE_PREFIX_PATH=<path/to/your/installation/of/library/> -DCMAKE_BUILD_TYPE=Release
make
```
Make sure that the `FTorch_DIR` flag points to the `lib/cmake/` folder within the installation of the FTorch library.

To run the compiled code calling the saved SimpleNet TorchScript from Fortran run the
executable with an argument of the saved model file:
Expand Down
4 changes: 1 addition & 3 deletions examples/1_SimpleNet/simplenet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module defining a simple PyTorch 'Net' for coupling to Fortran. """
"""Module defining a simple PyTorch 'Net' for coupling to Fortran."""
import torch
from torch import nn

Expand All @@ -15,7 +15,6 @@ def __init__(
Consists of a single Linear layer with weights predefined to
multiply the input by 2.
"""

super().__init__()
self._fwd_seq = nn.Sequential(
nn.Linear(5, 5, bias=False),
Expand All @@ -38,7 +37,6 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
batch scaled by 2.
"""

return self._fwd_seq(batch)


Expand Down
Binary file added examples/1_SimpleNet/simplenet_infer_fortran
Binary file not shown.
40 changes: 18 additions & 22 deletions examples/1_SimpleNet/simplenet_infer_fortran.f90
Original file line number Diff line number Diff line change
@@ -1,47 +1,45 @@
program inference

! Imports primitives used to interface with C
use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc
! Import precision info from iso
use, intrinsic :: iso_fortran_env, only : sp => real32

! Import our library for interfacing with PyTorch
use ftorch

implicit none


! Set precision for reals
integer, parameter :: wp = sp

integer :: num_args, ix
character(len=128), dimension(:), allocatable :: args

! Set up types of input and output data and the interface with C
! Set up Fortran data structures
real(wp), dimension(5), target :: in_data
real(wp), dimension(5), target :: out_data
integer, parameter :: n_inputs = 1
integer :: tensor_layout(1) = [1]

! Set up Torch data structures
type(torch_module) :: model
type(torch_tensor), dimension(1) :: in_tensor
type(torch_tensor) :: out_tensor

real(c_float), dimension(:), allocatable, target :: in_data
integer(c_int), parameter :: n_inputs = 1
real(c_float), dimension(:), allocatable, target :: out_data

integer(c_int), parameter :: tensor_dims = 1
integer(c_int64_t) :: tensor_shape(tensor_dims) = [5]
integer(c_int) :: tensor_layout(tensor_dims) = [1]

! Get TorchScript model file as a command line argument
num_args = command_argument_count()
allocate(args(num_args))
do ix = 1, num_args
call get_command_argument(ix,args(ix))
end do

! Allocate one-dimensional input/output arrays, based on multiplication of all input/output dimension sizes
allocate(in_data(tensor_shape(1)))
allocate(out_data(tensor_shape(1)))

! Initialise data
in_data = [0.0, 1.0, 2.0, 3.0, 4.0]

! Create input/output tensors from the above arrays
in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
out_tensor = torch_tensor_from_blob(c_loc(out_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
! Create Torch input/output tensors from the above arrays
in_tensor(1) = torch_tensor_from_array(in_data, tensor_layout, torch_kCPU)
out_tensor = torch_tensor_from_array(out_data, tensor_layout, torch_kCPU)

! Load ML model (edit this line to use different models)
! Load ML model
model = torch_module_load(args(1))

! Infer
Expand All @@ -52,7 +50,5 @@ program inference
call torch_module_delete(model)
call torch_tensor_delete(in_tensor(1))
call torch_tensor_delete(out_tensor)
deallocate(in_data)
deallocate(out_data)

end program inference
Loading

0 comments on commit 4f667bb

Please sign in to comment.