Skip to content

Commit

Permalink
decouple torch/types.h
Browse files Browse the repository at this point in the history
  • Loading branch information
hugary1995 committed Feb 8, 2025
1 parent 6c06ff7 commit 27fb0c5
Show file tree
Hide file tree
Showing 154 changed files with 1,121 additions and 934 deletions.
2 changes: 1 addition & 1 deletion doc/content/system/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ While the convergence criteria are defined by the specific solvers derived from
bool
MySolver::converged(const torch::Tensor & nR, const torch::Tensor & nR0) const
{
return torch::all(torch::logical_or(nR < atol, nR / nR0 < rtol)).item<bool>();
return at::all(at::logical_or(nR < atol, nR / nR0 < rtol)).item<bool>();
}
```
where `nR` is the vector norm of the current residual, and `nR0` is the vector norm of the initial residual (evaluated at the initial guess). The above statement makes sure the current residual is either below the absolute tolerance or has been sufficiently reduced, and the condition is applied to _all_ batches of the residual norm.
Expand Down
8 changes: 4 additions & 4 deletions doc/content/system/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ neml2::Tensor is a general-purpose *dynamically shaped* tensor type for batched
A `Tensor` can be created from a `torch::Tensor` and a batch dimension:
```cpp
Tensor A(torch::rand({1, 1, 5, 2}), 2);
Tensor A(at::rand({1, 1, 5, 2}), 2);
```
The batch sizes of `A` is `(1, 1)`:
```cpp
Expand Down Expand Up @@ -72,9 +72,9 @@ For example, the following code
```cpp
auto a = SR2::zeros({5, 3},
torch::TensorOptions()
.device(torch::kCPU)
.device(kCPU)
.layout(torch::kStrided)
.dtype(torch::kFloat32));
.dtype(kFloat32));
```
creates a statically (base) shaped, dense, single precision tensor of type `SR2` filled with zeros, with batch shape \f$(5, 3)\f$, allocated on the CPU.
Expand Down Expand Up @@ -123,7 +123,7 @@ A.base_index_put_({Slice(1, 3)}, torch::ones({3, 2}));
// A = [[ 2 1 1]
// [ -1 1 1]
// [ 6 1 1]]
A.batch_index_put_({Slice(0, 2)}, torch::zeros({2, 3}));
A.batch_index_put_({Slice(0, 2)}, at::zeros({2, 3}));
// A = [[ 0 0 0]
// [ 0 0 0]
// [ 6 1 1]]
Expand Down
2 changes: 1 addition & 1 deletion include/neml2/dispatchers/StaticHybridScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class StaticHybridScheduler : public WorkScheduler
* The constructor takes a device list, along with the batch sizes, capacities, and priorities for
* each device.
*
* The device list should be unique and non-empty. torch::kCPU can appear at most once. When
* The device list should be unique and non-empty. kCPU can appear at most once. When
* multiple cuda devices are present, each of them must correspond to a specific device ID.
*
* One or more batch size should be provided. If the number of batch sizes is one, the same batch
Expand Down
4 changes: 2 additions & 2 deletions include/neml2/dispatchers/WorkDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ WorkDispatcher<I, O, Of, Ip, Op>::run(WorkGenerator<Ip> & generator,
{
validate();

torch::Device device = torch::kCPU;
torch::Device device = kCPU;
std::size_t n = 0;
std::vector<Op> results;
while (generator.has_more())
Expand Down Expand Up @@ -201,7 +201,7 @@ WorkDispatcher<I, O, Of, Ip, Op>::run_async(WorkGenerator<Ip> & generator,
{
validate();

torch::Device device = torch::kCPU;
torch::Device device = kCPU;
std::size_t n = 0;
using FutureResult = std::tuple<std::size_t, torch::Device, std::size_t, std::future<O>>;
std::vector<FutureResult> future_results;
Expand Down
2 changes: 1 addition & 1 deletion include/neml2/misc/defaults.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace neml2
/**
* @name RAII style default tensor options
*
* The factory methods like `torch::arange`, `torch::ones`, `torch::zeros`, `torch::rand` etc.
 * The factory methods like `at::arange`, `at::ones`, `at::zeros`, `at::rand` etc.
* accept a common argument to configure the properties of the tensor being created. We predefine
* a default tensor configuration in NEML2. This default configuration is consistently used
* throughout NEML2.
Expand Down
35 changes: 35 additions & 0 deletions include/neml2/misc/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include <iosfwd>
#include <limits>

#include <c10/core/ScalarType.h>
#include <c10/core/DeviceType.h>

namespace c10
{
template <typename T, unsigned N>
Expand Down Expand Up @@ -64,6 +67,38 @@ using Device = c10::Device;

namespace neml2
{
/// Fixed width dtypes (mirroring the definition in <torch/csrc/api/include/torch/types.h>)
constexpr auto kUInt8 = c10::kByte;
constexpr auto kInt8 = c10::kChar;
constexpr auto kInt16 = c10::kShort;
constexpr auto kInt32 = c10::kInt;
constexpr auto kInt64 = c10::kLong;
constexpr auto kUInt16 = c10::kUInt16;
constexpr auto kUInt32 = c10::kUInt32;
constexpr auto kUInt64 = c10::kUInt64;
constexpr auto kFloat16 = c10::kHalf;
constexpr auto kFloat32 = c10::kFloat;
constexpr auto kFloat64 = c10::kDouble;

// Device types (mirroring the definition in <c10/core/DeviceType.h>)
constexpr auto kCPU = c10::DeviceType::CPU;
constexpr auto kCUDA = c10::DeviceType::CUDA;
constexpr auto kHIP = c10::DeviceType::HIP;
constexpr auto kFPGA = c10::DeviceType::FPGA;
constexpr auto kMAIA = c10::DeviceType::MAIA;
constexpr auto kXLA = c10::DeviceType::XLA;
constexpr auto kMPS = c10::DeviceType::MPS;
constexpr auto kMeta = c10::DeviceType::Meta;
constexpr auto kVulkan = c10::DeviceType::Vulkan;
constexpr auto kMetal = c10::DeviceType::Metal;
constexpr auto kXPU = c10::DeviceType::XPU;
constexpr auto kHPU = c10::DeviceType::HPU;
constexpr auto kVE = c10::DeviceType::VE;
constexpr auto kLazy = c10::DeviceType::Lazy;
constexpr auto kIPU = c10::DeviceType::IPU;
constexpr auto kMTIA = c10::DeviceType::MTIA;
constexpr auto kPrivateUse1 = c10::DeviceType::PrivateUse1;

using Real = double;
using Size = int64_t;
using Integer = int64_t;
Expand Down
2 changes: 1 addition & 1 deletion include/neml2/models/Model.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Model : public std::enable_shared_from_this<Model>,
struct TraceSchema
{
std::vector<Size> batch_dims;
torch::DispatchKey dispatch_key;
at::DispatchKey dispatch_key;
bool operator==(const TraceSchema & other) const;
bool operator<(const TraceSchema & other) const;
};
Expand Down
4 changes: 2 additions & 2 deletions include/neml2/models/crystallography/CrystalGeometry.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ CrystalGeometry::slip_slice(const Derived & tensor, Size grp) const
{
if (grp >= nslip_groups())
throw NEMLException("Invalid slip group index");
return tensor.batch_index({torch::indexing::Ellipsis,
torch::indexing::Slice(_slip_offsets[grp], _slip_offsets[grp + 1])});
return tensor.batch_index(
{indexing::Ellipsis, indexing::Slice(_slip_offsets[grp], _slip_offsets[grp + 1])});
}

} // namespace crystallography
Expand Down
2 changes: 1 addition & 1 deletion include/neml2/tensors/Rot.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Rot : public VecBase<Rot>
identity(const torch::TensorOptions & options = default_tensor_options());

/// Fill from an array of Euler angles
static Rot fill_euler_angles(const torch::Tensor & vals,
static Rot fill_euler_angles(const Vec & v,
const std::string & angle_convention,
const std::string & angle_type);

Expand Down
4 changes: 2 additions & 2 deletions include/neml2/tensors/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#pragma once

#include <torch/types.h>
#include <ATen/core/Tensor.h>
#include "neml2/jit/TraceableTensorShape.h"
#include "neml2/tensors/shape_utils.h"
#include "neml2/tensors/functions/operators.h"
Expand Down Expand Up @@ -101,7 +101,7 @@ class TensorBase : public torch::Tensor
/// @name Meta operations
///@{
/// Clone (take ownership)
Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
Derived clone() const;
/// Discard function graph
Derived detach() const;
/// Detach from gradient graphs in place
Expand Down
28 changes: 14 additions & 14 deletions include/neml2/tensors/TensorBaseImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,28 @@ template <class Derived>
Derived
TensorBase<Derived>::empty_like(const Derived & other)
{
return Derived(torch::empty_like(other), other.batch_sizes());
return Derived(at::empty_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::zeros_like(const Derived & other)
{
return Derived(torch::zeros_like(other), other.batch_sizes());
return Derived(at::zeros_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::ones_like(const Derived & other)
{
return Derived(torch::ones_like(other), other.batch_sizes());
return Derived(at::ones_like(other), other.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::full_like(const Derived & other, Real init)
{
return Derived(torch::full_like(other, init), other.batch_sizes());
return Derived(at::full_like(other, init), other.batch_sizes());
}

template <class Derived>
Expand All @@ -116,7 +116,7 @@ TensorBase<Derived>::linspace(const Derived & start, const Derived & end, Size n
indexing::TensorIndices net(dim, indexing::None);
net.push_back(indexing::Ellipsis);
net.insert(net.end(), Bd - dim, indexing::None);
Scalar steps(torch::arange(nstep, diff.options()).index(net) / (nstep - 1));
Scalar steps(at::arange(nstep, diff.options()).index(net) / (nstep - 1));

res = res + steps * diff;
}
Expand All @@ -130,14 +130,14 @@ TensorBase<Derived>::logspace(
const Derived & start, const Derived & end, Size nstep, Size dim, Real base)
{
auto exponent = neml2::Tensor::linspace(start, end, nstep, dim);
return Derived(torch::pow(base, exponent), exponent.batch_sizes());
return Derived(at::pow(base, exponent), exponent.batch_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::clone(torch::MemoryFormat memory_format) const
TensorBase<Derived>::clone() const
{
return Derived(torch::Tensor::clone(memory_format), batch_sizes());
return Derived(torch::Tensor::clone(), batch_sizes());
}

template <class Derived>
Expand Down Expand Up @@ -221,7 +221,7 @@ Derived
TensorBase<Derived>::batch_index(indexing::TensorIndicesRef indices) const
{
indexing::TensorIndices indices_vec(indices);
indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
auto res = this->index(indices_vec);
return Derived(res, res.dim() - base_dim());
}
Expand All @@ -230,7 +230,7 @@ template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_index(indexing::TensorIndicesRef indices) const
{
indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
indices2.insert(indices2.end(), indices.begin(), indices.end());
return neml2::Tensor(this->index(indices2), batch_sizes());
}
Expand Down Expand Up @@ -261,7 +261,7 @@ TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices,
const torch::Tensor & other)
{
indexing::TensorIndices indices_vec(indices);
indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
this->index_put_(indices_vec, other);
}

Expand All @@ -270,7 +270,7 @@ void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, Real v)
{
indexing::TensorIndices indices_vec(indices);
indices_vec.insert(indices_vec.end(), base_dim(), torch::indexing::Slice());
indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
this->index_put_(indices_vec, v);
}

Expand All @@ -279,7 +279,7 @@ void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices,
const torch::Tensor & other)
{
indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
indices2.insert(indices2.end(), indices.begin(), indices.end());
this->index_put_(indices2, other);
}
Expand All @@ -288,7 +288,7 @@ template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, Real v)
{
indexing::TensorIndices indices2(batch_dim(), torch::indexing::Slice());
indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
indices2.insert(indices2.end(), indices.begin(), indices.end());
this->index_put_(indices2, v);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace neml2
{
#define DECLARE_ARCCOS(T) T arccos(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ARCCOS);
#undef DECLARE_ARCCOS
#define DECLARE_ACOS(T) T acos(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ACOS);
#undef DECLARE_ACOS
} // namespace neml2
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace neml2
{
#define DECLARE_ARCSIN(T) T arcsin(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ARCSIN);
#undef DECLARE_ARCSIN
#define DECLARE_ASIN(T) T asin(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ASIN);
#undef DECLARE_ASIN
} // namespace neml2
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace neml2
{
#define DECLARE_ARCTAN(T) T arctan(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ARCTAN);
#undef DECLARE_ARCTAN
#define DECLARE_ATAN(T) T atan(const T & a)
FOR_ALL_TENSORBASE(DECLARE_ATAN);
#undef DECLARE_ATAN
} // namespace neml2
34 changes: 34 additions & 0 deletions include/neml2/tensors/functions/deg2rad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/tensors/tensors_fwd.h"

namespace neml2
{
// Declare `T deg2rad(const T & a)` for every NEML2 tensor type via the
// FOR_ALL_TENSORBASE X-macro (presumably degree-to-radian conversion,
// mirroring at::deg2rad — definitions live in the corresponding .cxx).
// The helper macro is undefined immediately to avoid leaking it to includers.
#define DECLARE_DEG2RAD(T) T deg2rad(const T & a)
FOR_ALL_TENSORBASE(DECLARE_DEG2RAD);
#undef DECLARE_DEG2RAD
} // namespace neml2
37 changes: 37 additions & 0 deletions include/neml2/tensors/functions/fmod.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/misc/types.h"
#include "neml2/tensors/tensors_fwd.h"

namespace neml2
{
// Declare two `fmod` overloads (tensor-tensor and tensor-Real; `Real` comes
// from neml2/misc/types.h) for every NEML2 tensor type via the
// FOR_ALL_TENSORBASE X-macro (presumably elementwise floating-point
// remainder, mirroring at::fmod — definitions live in the corresponding .cxx).
// The helper macro is undefined immediately to avoid leaking it to includers.
#define DECLARE_FMOD(T)                                                                            \
  T fmod(const T & a, const T & b);                                                                \
  T fmod(const T & a, const Real & b)
FOR_ALL_TENSORBASE(DECLARE_FMOD);
#undef DECLARE_FMOD
} // namespace neml2
2 changes: 2 additions & 0 deletions include/neml2/tensors/functions/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#pragma once

#include <ATen/TensorOperators.h>

#include "neml2/misc/types.h"
#include "neml2/tensors/macros.h"

Expand Down
Loading

0 comments on commit 27fb0c5

Please sign in to comment.