cleanup c10 and at namespaces
hugary1995 committed Feb 9, 2025
1 parent e4f4aa1 commit 31022c8
Showing 100 changed files with 436 additions and 500 deletions.
2 changes: 1 addition & 1 deletion doc/content/system/solver.md
@@ -19,7 +19,7 @@ The first argument is the nonlinear system of equations to be solved, and the se
 While the convergence criteria are defined by the specific solvers derived from the base class, it is generally recommended to use both `atol` and `rtol` in the convergence check. Below is an example convergence criterion:
 ```cpp
 bool
-MySolver::converged(const torch::Tensor & nR, const torch::Tensor & nR0) const
+MySolver::converged(const at::Tensor & nR, const at::Tensor & nR0) const
 {
   return at::all(at::logical_or(nR < atol, nR / nR0 < rtol)).item<bool>();
 }
10 changes: 5 additions & 5 deletions doc/content/system/tensor.md
@@ -6,15 +6,15 @@ Refer to [Syntax Documentation](@ref syntax-tensors) for the list of available o

 ## Tensor types {#tensor-types}
 
-Currently, PyTorch is the only supported tensor backend in NEML2. Therefore, all tensor types in NEML2 directly inherit from `torch::Tensor`. In the future, support for other tensor backends may be added, but the public-facing interfaces will remain largely the same.
+Currently, PyTorch is the only supported tensor backend in NEML2. Therefore, all tensor types in NEML2 directly inherit from `at::Tensor`. In the future, support for other tensor backends may be added, but the public-facing interfaces will remain largely the same.
 
 ### Dynamically shaped tensor {#dynamically-shaped-tensor}
 
 neml2::Tensor is a general-purpose *dynamically shaped* tensor type for batched tensors. With a view towards vectorization, the same set of operations can be "simultaneously" applied to a "batch" of tensors. To provide a unified user interface for dealing with such batched operations, NEML2 assumes that the *first* \f$N\f$ dimensions of a tensor are batch dimensions, and the following dimensions are the base dimensions.
 
 > Unlike PyTorch, NEML2 explicitly distinguishes between batch dimensions and base dimensions.
 
-A `Tensor` can be created from a `torch::Tensor` and a batch dimension:
+A `Tensor` can be created from an `at::Tensor` and a batch dimension:
 ```cpp
 Tensor A(at::rand({1, 1, 5, 2}), 2);
 ```
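Here `A` has batch shape `(1, 1)` and base shape `(5, 2)`. A minimal sketch of how the two kinds of dimensions might be inspected — the accessor names `batch_dim`, `batch_sizes`, and `base_sizes` are assumed from the NEML2 tensor API and do not appear in this diff:

```cpp
// Sketch: the first 2 dimensions are batch dimensions, the rest are base.
Tensor A(at::rand({1, 1, 5, 2}), 2);
auto nb = A.batch_dim();   // assumed accessor: 2
auto bs = A.batch_sizes(); // assumed accessor: (1, 1)
auto es = A.base_sizes();  // assumed accessor: (5, 2)
```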
@@ -60,9 +60,9 @@ Furthermore, all primitive tensor types can be "registered" as variables on a `L
 A factory tensor creation function produces a new tensor. All factory functions adhere to the same schema:
 ```cpp
-<TensorType>::<function_name>(<function-specific-options>, const torch::TensorOptions & options);
+<TensorType>::<function_name>(<function-specific-options>, const TensorOptions & options);
 ```
-where `<TensorType>` is the class name of the primitive tensor type listed above, and `<function_name>` is the name of the factory function which produces the new tensor. `<function-specific-options>` are any required or optional arguments a particular factory function accepts. Refer to each tensor type's class documentation for the concrete signature. The last argument `const torch::TensorOptions & options` configures the data type, device, layout, and other "meta" properties of the produced tensor. The commonly used meta properties are
+where `<TensorType>` is the class name of the primitive tensor type listed above, and `<function_name>` is the name of the factory function which produces the new tensor. `<function-specific-options>` are any required or optional arguments a particular factory function accepts. Refer to each tensor type's class documentation for the concrete signature. The last argument `const TensorOptions & options` configures the data type, device, layout, and other "meta" properties of the produced tensor. The commonly used meta properties are
 - `dtype`: the data type of the elements stored in the tensor. Available options are `kUInt8`, `kInt8`, `kInt16`, `kInt32`, `kInt64`, `kFloat32`, and `kFloat64`.
 - `layout`: the striding of the tensor. Available options are `kStrided` (dense) and `kSparse`.
 - `device`: the compute device where the tensor will be allocated. Available options are `kCPU` and `kCUDA`.
@@ -71,7 +71,7 @@ where `<TensorType>` is the class name of the primitive tensor type listed above
 For example, the following code
 ```cpp
 auto a = SR2::zeros({5, 3},
-                    torch::TensorOptions()
+                    TensorOptions()
                         .device(kCPU)
                         .layout(torch::kStrided)
                         .dtype(kFloat32));
2 changes: 1 addition & 1 deletion doc/content/user/input_file.md
@@ -76,7 +76,7 @@ The cross-referencing mechanism allows object options in the input file to _refe
 ```
 In the above example, the object of type `ImplicitUpdate` references an implicit model named "implicit_rate" and a solver named "newton".
 
-In addition to directly referencing objects by their names, a few special types support more flexible referencing mechanisms. `torch::Tensor`, `Tensor`, and all primitive tensor types with fixed base shapes such as `Scalar`, `SR2`, etc., can be referenced either by value or by name.
+In addition to directly referencing objects by their names, a few special types support more flexible referencing mechanisms. `at::Tensor`, `Tensor`, and all primitive tensor types with fixed base shapes such as `Scalar`, `SR2`, etc., can be referenced either by value or by name.
 
 When a tensor is referenced by value, the parser will parse the input option value as a numeric literal and return a tensor filled with the specified value; when a tensor is referenced by name, the parser will look for and return the object under the `[Tensors]` section with the given name.
 
12 changes: 5 additions & 7 deletions include/neml2/base/LabeledAxisAccessor.h
@@ -27,8 +27,6 @@
 #include <vector>
 #include <iosfwd>
 
-#include <c10/util/SmallVector.h>
-
 namespace neml2
 {
 // Reserved subaxis names
@@ -89,13 +87,13 @@ class LabeledAxisAccessor
                       std::string>>>
   LabeledAxisAccessor(const Container & c)
   {
-    _item_names.append(c.begin(), c.end());
+    _item_names = std::vector<std::string>(c.begin(), c.end());
     for (const auto & name : _item_names)
       validate_item_name(name);
   }
 
-  using iterator = c10::SmallVector<std::string>::iterator;
-  using const_iterator = c10::SmallVector<std::string>::const_iterator;
+  using iterator = std::vector<std::string>::iterator;
+  using const_iterator = std::vector<std::string>::const_iterator;
 
   /**
    * @name Iterators
@@ -111,7 +109,7 @@

   explicit operator std::vector<std::string>() const;
 
-  const c10::SmallVector<std::string> & vec() const { return _item_names; }
+  const std::vector<std::string> & vec() const { return _item_names; }
 
   std::string str() const;
 
@@ -162,7 +160,7 @@ class LabeledAxisAccessor
   /// Throws if the item name has invalid format
   void validate_item_name(const std::string &) const;
 
-  c10::SmallVector<std::string> _item_names;
+  std::vector<std::string> _item_names;
 };
 
 /// Compare for equality between two LabeledAxisAccessor
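The container constructor above now copies item names into a `std::vector<std::string>`. A minimal usage sketch, assuming the SFINAE-constrained constructor accepts any container of strings, as the declaration suggests:

```cpp
#include <iostream>
#include <string>
#include <vector>
#include "neml2/base/LabeledAxisAccessor.h"

int main()
{
  // Item names are now stored in std::vector<std::string> rather than
  // c10::SmallVector<std::string>, so the iterators below are plain
  // std::vector iterators.
  std::vector<std::string> names{"state", "internal", "gamma"};
  neml2::LabeledAxisAccessor acc(names);
  for (const auto & name : acc)
    std::cout << name << "\n";
}
```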
10 changes: 5 additions & 5 deletions include/neml2/dispatchers/SimpleScheduler.h
@@ -53,19 +53,19 @@ class SimpleScheduler : public WorkScheduler
    * be simultaneously handled by the device at any given time. The default capacity is set to the
    * maximum value of size_t
    */
-  SimpleScheduler(torch::Device device,
+  SimpleScheduler(Device device,
                   std::size_t batch_size,
                   std::size_t capacity = std::numeric_limits<std::size_t>::max());
 
-  bool schedule_work(torch::Device &, std::size_t &) const override;
+  bool schedule_work(Device &, std::size_t &) const override;
 
-  void dispatched_work(torch::Device, std::size_t) override;
+  void dispatched_work(Device, std::size_t) override;
 
-  void completed_work(torch::Device, std::size_t) override;
+  void completed_work(Device, std::size_t) override;
 
 private:
   /// The device to dispatch to
-  torch::Device _device;
+  Device _device;
 
   /// The batch size to dispatch
   std::size_t _batch_size;
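A usage sketch of the updated constructor (hypothetical values; `Device` is assumed to be NEML2's alias for the underlying `c10::Device`):

```cpp
// Sketch: always dispatch batches of 128 to a single CUDA device, with the
// default (unlimited) capacity.
neml2::SimpleScheduler scheduler(neml2::Device("cuda:0"), 128);
```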
19 changes: 8 additions & 11 deletions include/neml2/dispatchers/StaticHybridScheduler.h
@@ -41,10 +41,7 @@ class StaticHybridScheduler : public WorkScheduler
 public:
   struct DeviceStatus
   {
-    DeviceStatus(torch::Device device,
-                 std::size_t batch_size,
-                 std::size_t capacity,
-                 double priority)
+    DeviceStatus(Device device, std::size_t batch_size, std::size_t capacity, double priority)
       : device(device),
         batch_size(batch_size),
         capacity(capacity),
@@ -53,7 +50,7 @@
     {
     }
 
-    torch::Device device;
+    Device device;
     std::size_t batch_size;
     std::size_t capacity;
     double priority;
@@ -81,8 +78,8 @@
    * this dispatcher chooses the device to dispatch not only based on the priority but also based on
    * the availability of the device. See next() for more details.
    *
-   * \note For developers, below is a summary of the construct of torch::Device:
-   * torch::Device represents a compute device on which a tensor is located. A device is uniquely
+   * \note For developers, below is a summary of the construct of Device:
+   * Device represents a compute device on which a tensor is located. A device is uniquely
    * identified by a type, which specifies the type of machine it is (e.g. CPU or CUDA GPU), and a
    * device index or ordinal, which identifies the specific compute device when there is more than
    * one of a certain type. The device index is optional, and in its defaulted state represents
@@ -92,7 +89,7 @@
    * represents a specific, concrete device,
    * 2. When the device type is CPU, the device index must be zero.
    */
-  StaticHybridScheduler(const std::vector<torch::Device> & device_list,
+  StaticHybridScheduler(const std::vector<Device> & device_list,
                         const std::vector<std::size_t> & batch_sizes,
                         const std::vector<std::size_t> & capacities = {},
                         const std::vector<double> & priorities = {});
@@ -108,14 +105,14 @@
    * By default, the availability is the device's priority, a custom function can be set using
    * set_availability_calculator().
    */
-  bool schedule_work(torch::Device &, std::size_t &) const override;
+  bool schedule_work(Device &, std::size_t &) const override;
 
   /// Set a custom availability calculator
   void set_availability_calculator(std::function<double(const DeviceStatus &)>);
 
-  void dispatched_work(torch::Device, std::size_t) override;
+  void dispatched_work(Device, std::size_t) override;
 
-  void completed_work(torch::Device, std::size_t) override;
+  void completed_work(Device, std::size_t) override;
 
   const std::vector<DeviceStatus> & status() const { return _devices; }
 
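The note above describes how a `Device` is identified; a construction sketch combining that with the updated constructor (hypothetical device list and batch sizes):

```cpp
// Sketch: one CPU worker and two CUDA workers. A Device can be constructed
// from a "type:index" string; capacities and priorities keep their defaults.
std::vector<neml2::Device> devices{
    neml2::Device("cpu"), neml2::Device("cuda:0"), neml2::Device("cuda:1")};
neml2::StaticHybridScheduler scheduler(devices, /*batch_sizes=*/{16, 256, 256});
```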
18 changes: 9 additions & 9 deletions include/neml2/dispatchers/WorkDispatcher.h
@@ -74,21 +74,21 @@ template <typename I,
 class WorkDispatcher
 {
 public:
-  WorkDispatcher(std::function<O(I &&, torch::Device)> && dispatch)
+  WorkDispatcher(std::function<O(I &&, Device)> && dispatch)
     : _dispatch(std::move(dispatch))
   {
   }
 
-  WorkDispatcher(std::function<O(I &&, torch::Device)> && dispatch,
+  WorkDispatcher(std::function<O(I &&, Device)> && dispatch,
                  std::function<O(std::vector<O> &&)> && reduce)
     : _dispatch(std::move(dispatch)),
       _reduce(std::move(reduce))
   {
   }
 
-  WorkDispatcher(std::function<O(I &&, torch::Device)> && dispatch,
+  WorkDispatcher(std::function<O(I &&, Device)> && dispatch,
                  std::function<Of(std::vector<Op> &&)> && reduce,
-                 std::function<I(Ip &&, torch::Device)> && preprocess,
+                 std::function<I(Ip &&, Device)> && preprocess,
                  std::function<Op(O &&)> && postprocess)
     : _dispatch(std::move(dispatch)),
       _reduce(std::move(reduce)),
@@ -108,13 +108,13 @@
   void validate() const;
 
   /// Function to dispatch preprocessed work to the worker and retrieve the result
-  std::function<O(Ip &&, torch::Device)> _dispatch;
+  std::function<O(Ip &&, Device)> _dispatch;
 
   /// Function to reduce the results
   std::function<Of(std::vector<Op> &&)> _reduce;
 
   /// Function to preprocess the work
-  std::function<I(Ip &&, torch::Device)> _preprocess;
+  std::function<I(Ip &&, Device)> _preprocess;
 
   /// Function to postprocess the result
   std::function<Op(O &&)> _postprocess;
@@ -151,7 +151,7 @@ WorkDispatcher<I, O, Of, Ip, Op>::run(WorkGenerator<Ip> & generator,
 {
   validate();
 
-  torch::Device device = kCPU;
+  Device device = kCPU;
   std::size_t n = 0;
   std::vector<Op> results;
   while (generator.has_more())
@@ -201,9 +201,9 @@ WorkDispatcher<I, O, Of, Ip, Op>::run_async(WorkGenerator<Ip> & generator,
 {
   validate();
 
-  torch::Device device = kCPU;
+  Device device = kCPU;
   std::size_t n = 0;
-  using FutureResult = std::tuple<std::size_t, torch::Device, std::size_t, std::future<O>>;
+  using FutureResult = std::tuple<std::size_t, Device, std::size_t, std::future<O>>;
   std::vector<FutureResult> future_results;
   std::vector<Op> results;
   while (generator.has_more())
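For orientation, a sketch of the two-argument constructor after this change. This is hypothetical: it assumes the trailing template parameters default to the preceding ones, and the two helper functions stand in for real per-batch evaluation and reduction logic.

```cpp
// Hypothetical helpers standing in for real implementations.
neml2::Tensor evaluate_on(neml2::Tensor && x, neml2::Device dev);
neml2::Tensor concatenate_batches(std::vector<neml2::Tensor> && rs);

// Dispatch each chunk of work to the scheduled device, then reduce the
// per-chunk results into a single output.
neml2::WorkDispatcher<neml2::Tensor, neml2::Tensor> dispatcher(
    [](neml2::Tensor && x, neml2::Device dev)
    { return evaluate_on(std::move(x), dev); },
    [](std::vector<neml2::Tensor> && rs)
    { return concatenate_batches(std::move(rs)); });
```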
6 changes: 3 additions & 3 deletions include/neml2/dispatchers/WorkScheduler.h
@@ -50,12 +50,12 @@ class WorkScheduler
    * @return true If work has been scheduled, i.e., there is a worker available
    * @return false If work cannot be scheduled, i.e., there is no worker available
    */
-  virtual bool schedule_work(torch::Device &, std::size_t &) const = 0;
+  virtual bool schedule_work(Device &, std::size_t &) const = 0;
 
   /// Update the schedule with the dispatch of the last batch
-  virtual void dispatched_work(torch::Device, std::size_t) = 0;
+  virtual void dispatched_work(Device, std::size_t) = 0;
 
   /// Update the schedule with the completion of the last batch
-  virtual void completed_work(torch::Device, std::size_t) = 0;
+  virtual void completed_work(Device, std::size_t) = 0;
 };
 } // namespace neml2
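A sketch of a trivial scheduler implementing this interface. This is hypothetical and for illustration only: it ignores any constructor arguments the real base class may require, always offers the same device, and never exerts back-pressure.

```cpp
#include "neml2/dispatchers/WorkScheduler.h"

namespace neml2
{
// Hypothetical: always offer one device with a fixed batch size.
class FixedScheduler : public WorkScheduler
{
public:
  FixedScheduler(Device device, std::size_t batch_size)
    : _device(device),
      _batch_size(batch_size)
  {
  }

  bool schedule_work(Device & device, std::size_t & n) const override
  {
    device = _device;
    n = _batch_size;
    return true; // a worker is always "available"
  }

  void dispatched_work(Device, std::size_t) override {}
  void completed_work(Device, std::size_t) override {}

private:
  Device _device;
  std::size_t _batch_size;
};
} // namespace neml2
```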
2 changes: 1 addition & 1 deletion include/neml2/drivers/TransientDriver.h
@@ -98,7 +98,7 @@ class TransientDriver : public Driver
   /// The model which the driver uses to perform constitutive updates.
   Model & _model;
   /// The device on which to evaluate the model
-  const torch::Device _device;
+  const Device _device;
 
   /// VariableName for the time
   const VariableName _time_name;
10 changes: 5 additions & 5 deletions include/neml2/jit/TraceableSize.h
@@ -35,18 +35,18 @@ namespace neml2
  * Similar to neml2::TraceableTensorShape, but only for a single dimension.
  * @see neml2::TraceableTensorShape
  */
-struct TraceableSize : public std::variant<Size, torch::Tensor>
+struct TraceableSize : public std::variant<Size, at::Tensor>
 {
-  using std::variant<Size, torch::Tensor>::variant;
+  using std::variant<Size, at::Tensor>::variant;
 
-  /// @return a pointer to the torch::Tensor representing the traceable size if it is traceable, otherwise a nullptr
-  const torch::Tensor * traceable() const noexcept;
+  /// @return a pointer to the at::Tensor representing the traceable size if it is traceable, otherwise a nullptr
+  const at::Tensor * traceable() const noexcept;
 
   /// @return the concrete size (without any traceable information)
   Size concrete() const;
 
   /// @return the size represented as a scalar tensor (possibly traceable)
-  torch::Tensor as_tensor() const;
+  at::Tensor as_tensor() const;
 };
 
 /// Comparison operators
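Since `TraceableSize` is a `std::variant` over a concrete `Size` and a traced `at::Tensor`, consumers typically branch on which alternative is held; a minimal sketch (hypothetical usage):

```cpp
#include <iostream>
#include "neml2/jit/TraceableSize.h"

void print_size(const neml2::TraceableSize & s)
{
  // traceable() returns a pointer to the underlying at::Tensor while
  // tracing, otherwise nullptr.
  if (const auto * t = s.traceable())
    std::cout << "traced size: " << *t << "\n";
  else
    std::cout << "concrete size: " << s.concrete() << "\n";
}
```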
8 changes: 4 additions & 4 deletions include/neml2/jit/TraceableTensorShape.h
@@ -35,15 +35,15 @@ namespace neml2
  * A tensor shape can be either a concrete shape or a traceable tensor. This is useful when we need
  * to trace a function graph and let it generalize to other batch shapes.
  */
-struct TraceableTensorShape : public torch::SmallVector<TraceableSize, 8>
+struct TraceableTensorShape : public SmallVector<TraceableSize, 8>
 {
-  using torch::SmallVector<TraceableSize, 8>::SmallVector;
+  using SmallVector<TraceableSize, 8>::SmallVector;
   using Size = int64_t;
 
   TraceableTensorShape(const TensorShape & shape);
   TraceableTensorShape(TensorShapeRef shape);
   TraceableTensorShape(Size shape);
-  TraceableTensorShape(const torch::Tensor & shape);
+  TraceableTensorShape(const at::Tensor & shape);
 
   /// Slice the shape, semantically the same as ArrayRef::slice, but traceable.
   TraceableTensorShape slice(Size start, Size end) const;
@@ -55,7 +55,7 @@ struct TraceableTensorShape : public torch::SmallVector<TraceableSize, 8>
   TensorShape concrete() const;
 
   /// @return the shape represented as a scalar tensor (possibly traceable)
-  torch::Tensor as_tensor() const;
+  at::Tensor as_tensor() const;
 };
 
 /// Comparison operators
2 changes: 1 addition & 1 deletion include/neml2/jit/utils.h
@@ -46,7 +46,7 @@ namespace utils
 {
 /// @brief Extract the batch shape of a tensor given batch dimension
 /// The extracted batch shape will be _traceable_. @see neml2::TraceableTensorShape
-TraceableTensorShape extract_batch_sizes(const torch::Tensor & tensor, Size batch_dim);
+TraceableTensorShape extract_batch_sizes(const at::Tensor & tensor, Size batch_dim);
 
 template <typename... S>
 TraceableTensorShape add_traceable_shapes(const S &... shape);
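A usage sketch of the updated signature (hypothetical shapes):

```cpp
// Treat the first two dimensions as batch dimensions; the extracted shape
// is (3, 4) and, per the doc comment above, stays traceable under tracing.
auto t = at::rand({3, 4, 6});
auto batch_sizes = neml2::utils::extract_batch_sizes(t, 2);
```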
10 changes: 5 additions & 5 deletions include/neml2/misc/defaults.h
@@ -41,15 +41,15 @@ namespace neml2
  */
 ///@{
 /// Default floating point tensor options
-torch::TensorOptions & default_tensor_options();
+TensorOptions & default_tensor_options();
 /// Default integral tensor options
-torch::TensorOptions & default_integer_tensor_options();
+TensorOptions & default_integer_tensor_options();
 /// Default floating point type
-torch::Dtype & default_dtype();
+Dtype & default_dtype();
 /// Default integral type
-torch::Dtype & default_integer_dtype();
+Dtype & default_integer_dtype();
 /// Default device
-torch::Device & default_device();
+Device & default_device();
 ///@}
 
 /// @name Default tolerances
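A sketch of how these accessors combine with the factory functions documented earlier (assuming `TensorOptions`, `Dtype`, and `Device` are NEML2's aliases for the corresponding `at::`/`c10::` types):

```cpp
// Start from the process-wide defaults and override one property.
auto opts = neml2::default_tensor_options().device(neml2::Device("cuda:0"));
auto a = neml2::SR2::zeros({5, 3}, opts);

// The accessors return mutable references, so a default can also be
// changed globally (hypothetical usage).
neml2::default_dtype() = neml2::kFloat64;
```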