-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4e1c6c6
commit 51ba44c
Showing
231 changed files
with
289,028 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
_build | ||
*.install | ||
*.merlin | ||
_opam | ||
|
||
__pycache__ | ||
*.annot | ||
*.cmo | ||
*.cma | ||
*.cmi | ||
*.a | ||
*.o | ||
*.cmx | ||
*.cmxs | ||
*.cmxa | ||
*.swp | ||
*.ipynb_checkpoints | ||
|
||
data/ | ||
|
||
*.byte | ||
*.native | ||
|
||
setup.data | ||
setup.log | ||
torch.install | ||
|
||
# generated files | ||
src/wrapper/cxx_flags.sexp | ||
src/wrapper/c_library_flags.sexp | ||
src/wrapper/torch_stubs.c | ||
src/wrapper/torch_bindings.ml | ||
src/wrapper/torch_generated.ml | ||
src/wrapper/torch_bindings_generated.ml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
profile = janestreet | ||
let-binding-spacing = compact |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
This repository contains open source software that is developed and | ||
maintained by [Jane Street][js]. | ||
|
||
Contributions to this project are welcome and should be submitted via | ||
GitHub pull requests. | ||
|
||
Signing contributions | ||
--------------------- | ||
|
||
We require that you sign your contributions. Your signature certifies | ||
that you wrote the patch or otherwise have the right to pass it on as | ||
an open-source patch. The rules are pretty simple: if you can certify | ||
the below (from [developercertificate.org][dco]): | ||
|
||
``` | ||
Developer Certificate of Origin | ||
Version 1.1 | ||
Copyright (C) 2004, 2006 The Linux Foundation and its contributors. | ||
1 Letterman Drive | ||
Suite D4700 | ||
San Francisco, CA, 94129 | ||
Everyone is permitted to copy and distribute verbatim copies of this | ||
license document, but changing it is not allowed. | ||
Developer's Certificate of Origin 1.1 | ||
By making a contribution to this project, I certify that: | ||
(a) The contribution was created in whole or in part by me and I | ||
have the right to submit it under the open source license | ||
indicated in the file; or | ||
(b) The contribution is based upon previous work that, to the best | ||
of my knowledge, is covered under an appropriate open source | ||
license and I have the right under that license to submit that | ||
work with modifications, whether created in whole or in part | ||
by me, under the same open source license (unless I am | ||
permitted to submit under a different license), as indicated | ||
in the file; or | ||
(c) The contribution was provided directly to me by some other | ||
person who certified (a), (b) or (c) and I have not modified | ||
it. | ||
(d) I understand and agree that this project and the contribution | ||
are public and that a record of the contribution (including all | ||
personal information I submit with it, including my sign-off) is | ||
maintained indefinitely and may be redistributed consistent with | ||
this project or the open source license(s) involved. | ||
``` | ||
|
||
Then you just add a line to every git commit message: | ||
|
||
``` | ||
Signed-off-by: Joe Smith <[email protected]> | ||
``` | ||
|
||
Use your real name (sorry, no pseudonyms or anonymous contributions.) | ||
|
||
If you set your `user.name` and `user.email` git configs, you can sign | ||
your commit automatically with git commit -s. | ||
|
||
[dco]: http://developercertificate.org/ | ||
[js]: https://opensource.janestreet.com/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
The MIT License | ||
|
||
Copyright (c) 2022--2023 Jane Street Group, LLC <[email protected]> | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Some minimal contributions to ocaml-torch, originally contributed under the Apache 2.0 license, have been sublicensed under the MIT license, such that the whole of ocaml-torch is now licensed or sublicensed under the MIT license. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
INSTALL_ARGS := $(if $(PREFIX),--prefix $(PREFIX),) | ||
|
||
default: | ||
dune build | ||
|
||
install: | ||
dune install $(INSTALL_ARGS) | ||
|
||
uninstall: | ||
dune uninstall $(INSTALL_ARGS) | ||
|
||
reinstall: uninstall install | ||
|
||
clean: | ||
dune clean | ||
|
||
.PHONY: default install uninstall reinstall clean |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,152 @@ | ||
# torch | ||
# ocaml-torch | ||
__ocaml-torch__ provides some ocaml bindings for the [PyTorch](https://pytorch.org) tensor library. | ||
This brings to OCaml NumPy-like tensor computations with GPU acceleration and tape-based automatic | ||
differentiation. | ||
|
||
These bindings use the [PyTorch C++ API](https://pytorch.org/cppdocs/) and are | ||
mostly automatically generated. The current GitHub tip and the opam package v0.7 | ||
corresponds to PyTorch **v1.13.0**. | ||
|
||
On Linux note that you will need the PyTorch version using the appropriate cxx11 abi depending on your g++ version. | ||
[cpu version](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip), | ||
[cuda 11.6 version](https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu116.zip). | ||
|
||
## Opam Installation | ||
|
||
The [opam](https://opam.ocaml.org/) package can be installed using the following command. | ||
This automatically installs the CPU version of libtorch. | ||
|
||
```bash | ||
opam install torch | ||
``` | ||
|
||
You can then compile some sample code, see some instructions below. | ||
__ocaml-torch__ can also be used in interactive mode via | ||
[utop](https://github.com/ocaml-community/utop) or | ||
[ocaml-jupyter](https://github.com/akabe/ocaml-jupyter). | ||
|
||
Here is a sample utop session. | ||
|
||
 | ||
|
||
|
||
### Build a Simple Example | ||
|
||
To build a first torch program, create a file `example.ml` with the | ||
following content. | ||
|
||
```ocaml | ||
open Torch | ||
let () = | ||
let tensor = Tensor.randn [ 4; 2 ] in | ||
Tensor.print tensor | ||
``` | ||
|
||
Then create a `dune` file with the following content: | ||
|
||
```ocaml | ||
(executables | ||
(names example) | ||
(libraries torch)) | ||
``` | ||
|
||
Run `dune exec example.exe` to compile the program and run it! | ||
|
||
Alternatively you can first compile the code via `dune build example.exe` then run the executable | ||
`_build/default/example.exe` (note that building the bytecode target `example.bc` may | ||
not work on macos). | ||
|
||
## Tutorials and Examples | ||
|
||
* [MNIST tutorial](./examples/mnist/README.md). | ||
* [Finetuning a ResNet-18 model](./examples/pretrained/README.md). | ||
* [Generative Adversarial Networks](./examples/gan/README.md). | ||
* [Running some Python model](./examples/jit/README.md). | ||
|
||
Some more advanced applications from external repos: | ||
* An [OCaml port of mini-dalle](https://github.com/ArulselvanMadhavan/mini_dalle) by Arulselvan Madhavan. | ||
* Natural Language Processing models based on BERT can be found in the | ||
[ocaml-bert repo](https://github.com/LaurentMazare/ocaml-bert). | ||
|
||
## Sample Code | ||
|
||
Below is an example of a linear model trained on the MNIST dataset ([full | ||
code](./examples/mnist/README.md)). | ||
|
||
```ocaml | ||
(* Create two tensors to store model weights. *) | ||
let ws = Tensor.zeros [image_dim; label_count] ~requires_grad:true in | ||
let bs = Tensor.zeros [label_count] ~requires_grad:true in | ||
let model xs = Tensor.(mm xs ws + bs) in | ||
for index = 1 to 100 do | ||
(* Compute the cross-entropy loss. *) | ||
let loss = | ||
Tensor.cross_entropy_for_logits (model train_images) ~targets:train_labels | ||
in | ||
Tensor.backward loss; | ||
(* Apply gradient descent, disable gradient tracking for these. *) | ||
Tensor.(no_grad (fun () -> | ||
ws -= grad ws * f learning_rate; | ||
bs -= grad bs * f learning_rate)); | ||
(* Compute the validation error. *) | ||
let test_accuracy = | ||
Tensor.(argmax ~dim:(-1) (model test_images) = test_labels) | ||
|> Tensor.to_kind ~kind:(T Float) | ||
|> Tensor.sum | ||
|> Tensor.float_value | ||
|> fun sum -> sum /. test_samples | ||
in | ||
printf "%d %f %.2f%%\n%!" index (Tensor.float_value loss) (100. *. test_accuracy); | ||
done | ||
``` | ||
|
||
* Some [ResNet examples on CIFAR-10](./examples/cifar/README.md). | ||
* A simplified version of | ||
[char-rnn](./examples/char_rnn/README.md) | ||
illustrating character level language modeling using Recurrent Neural Networks. | ||
* [Neural Style Transfer](./examples/neural_transfer/README.md) | ||
applies the style of an image to the content of another image. This uses some deep Convolutional Neural Network. | ||
|
||
## Models and Weights | ||
|
||
Various pre-trained computer vision models are implemented in the vision library. | ||
The weight files can be downloaded at the following links: | ||
|
||
|
||
* ResNet-18 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet18.ot). | ||
* ResNet-34 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet34.ot). | ||
* ResNet-50 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet50.ot). | ||
* ResNet-101 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet101.ot). | ||
* ResNet-152 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet152.ot). | ||
* DenseNet-121 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/densenet121.ot). | ||
* DenseNet-161 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/densenet161.ot). | ||
* DenseNet-169 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/densenet169.ot). | ||
* SqueezeNet 1.0 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/squeezenet1_0.ot). | ||
* SqueezeNet 1.1 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/squeezenet1_1.ot). | ||
* VGG-13 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/vgg13.ot). | ||
* VGG-16 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/vgg16.ot). | ||
* AlexNet [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/alexnet.ot). | ||
* Inception-v3 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/inception-v3.ot). | ||
* MobileNet-v2 [weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/mobilenet-v2.ot). | ||
* EfficientNet | ||
[b0 weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/efficientnet-b0.ot), | ||
[b1 weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/efficientnet-b1.ot), | ||
[b2 weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/efficientnet-b2.ot), | ||
[b3 weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/efficientnet-b3.ot), | ||
[b4 weights](https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/efficientnet-b4.ot). | ||
|
||
Running the pre-trained models on some sample images can the easily be done via the following commands. | ||
```bash | ||
dune exec examples/pretrained/predict.exe path/to/resnet18.ot images/tiger.jpg | ||
``` | ||
|
||
## Acknowledgements | ||
|
||
Many thanks to [@LaurentMazare](https://github.com/LaurentMazare) for the [original | ||
work](https://github.com/LaurentMazare/ocaml-torch) of ocaml-torch. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
(executables (names tensor_tools) | ||
(libraries base cmdliner npy torch_core stdio torch torch_vision) | ||
(preprocess (pps ppx_jane))) |
Oops, something went wrong.