Skip to content

Commit

Permalink
Fix jacobian collection
Browse files Browse the repository at this point in the history
  • Loading branch information
hugary1995 authored and reverendbedford committed Jan 2, 2025
1 parent 5d22321 commit c4948a0
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 21 deletions.
6 changes: 6 additions & 0 deletions include/neml2/models/BufferStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ class BufferStore
template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_buffer(const std::string & name, const std::string & input_option_name);

/// Assign stack to buffers
void assign_buffer_stack(torch::jit::Stack & stack);

/// Collect stack from buffers
torch::jit::Stack collect_buffer_stack() const;

private:
NEML2Object * _object;

Expand Down
10 changes: 5 additions & 5 deletions runner/benchmark/tcpsingle/model.i
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@
[Data]
[crystal_geometry]
type = CubicCrystal
lattice_parameter = "a"
slip_directions = "sdirs"
slip_planes = "splanes"
lattice_parameter = 'a'
slip_directions = 'sdirs'
slip_planes = 'splanes'
[]
[]

Expand All @@ -127,8 +127,8 @@
type = LinearIsotropicElasticity
coefficients = '1e5 0.3'
coefficient_types = 'YOUNGS_MODULUS POISSONS_RATIO'
strain = "state/elastic_strain"
stress = "state/internal/cauchy_stress"
strain = 'state/elastic_strain'
stress = 'state/internal/cauchy_stress'
[]
[resolved_shear]
type = ResolvedShear
Expand Down
35 changes: 35 additions & 0 deletions src/neml2/models/BufferStore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,39 @@ BufferStore::declare_buffer(const std::string & name, const std::string & input_
template const T & BufferStore::declare_buffer<T>(const std::string &, const CrossRef<T> &); \
template const T & BufferStore::declare_buffer<T>(const std::string &, const std::string &)
FOR_ALL_TENSORBASE(BUFFERSTORE_INTANTIATE_TENSORBASE);

void
BufferStore::assign_buffer_stack(torch::jit::Stack & stack)
{
const auto & buffers = _object->host<BufferStore>()->named_buffers();

neml_assert_dbg(stack.size() >= buffers.size(),
"Stack size (",
stack.size(),
") is smaller than the number of buffers in the model (",
buffers.size(),
").");

// Last n tensors in the stack are the buffers
std::size_t i = stack.size() - buffers.size();
for (auto && [name, buffer] : buffers)
{
const auto tensor = stack[i++].toTensor();
buffer = Tensor(tensor, tensor.dim() - Tensor(buffer).base_dim());
}

// Drop the input variables from the stack
torch::jit::drop(stack, buffers.size());
}

torch::jit::Stack
BufferStore::collect_buffer_stack() const
{
const auto & buffers = _object->host<BufferStore>()->named_buffers();
torch::jit::Stack stack;
stack.reserve(buffers.size());
for (auto && [name, buffer] : buffers)
stack.push_back(Tensor(buffer));
return stack;
}
} // namespace neml2
57 changes: 44 additions & 13 deletions src/neml2/models/Model.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,11 @@ Model::value(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(true, false, false);
return collect_output();

const auto values = collect_output();
clear_input();
clear_output();
return values;
}

std::tuple<ValueMap, DerivMap>
Expand All @@ -356,7 +360,12 @@ Model::value_and_dvalue(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(true, true, false);
return {collect_output(), collect_output_derivatives()};

const auto values = collect_output();
const auto derivs = collect_output_derivatives();
clear_input();
clear_output();
return {values, derivs};
}

DerivMap
Expand All @@ -366,7 +375,11 @@ Model::dvalue(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(false, true, false);
return collect_output_derivatives();

const auto derivs = collect_output_derivatives();
clear_input();
clear_output();
return derivs;
}

std::tuple<ValueMap, DerivMap, SecDerivMap>
Expand All @@ -376,7 +389,13 @@ Model::value_and_dvalue_and_d2value(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(true, true, true);
return {collect_output(), collect_output_derivatives(), collect_output_second_derivatives()};

const auto values = collect_output();
const auto derivs = collect_output_derivatives();
const auto secderivs = collect_output_second_derivatives();
clear_input();
clear_output();
return {values, derivs, secderivs};
}

std::tuple<DerivMap, SecDerivMap>
Expand All @@ -386,7 +405,12 @@ Model::dvalue_and_d2value(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(false, true, true);
return {collect_output_derivatives(), collect_output_second_derivatives()};

const auto derivs = collect_output_derivatives();
const auto secderivs = collect_output_second_derivatives();
clear_input();
clear_output();
return {derivs, secderivs};
}

SecDerivMap
Expand All @@ -396,7 +420,11 @@ Model::d2value(const ValueMap & in)
assign_input(in);
zero_output();
forward_maybe_jit(false, false, true);
return collect_output_second_derivatives();

const auto secderivs = collect_output_second_derivatives();
clear_input();
clear_output();
return secderivs;
}

Model *
Expand Down Expand Up @@ -427,13 +455,16 @@ Model::provided_items() const
void
Model::assign_input_stack(torch::jit::Stack & stack)
{
neml_assert_dbg(stack.size() ==
input_axis().nvariable() + host<ParameterStore>()->named_parameters().size(),
"Stack size (",
stack.size(),
") must equal to the number of input variables and parameters in the model (",
input_axis().nvariable() + host<ParameterStore>()->named_parameters().size(),
").");
#ifndef NDEBUG
const auto nstack = input_axis().nvariable() + host<ParameterStore>()->named_parameters().size();
neml_assert_dbg(
stack.size() == nstack,
"Stack size (",
stack.size(),
") must equal to the number of input variables, parameters, and buffers in the model (",
nstack,
").");
#endif

assign_parameter_stack(stack);
VariableStore::assign_input_stack(stack);
Expand Down
8 changes: 5 additions & 3 deletions src/neml2/models/VariableStore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ VariableStore::collect_output_stack(bool out, bool dout, bool d2out) const
for (const auto & xvar : xvars)
{
const auto & deriv = derivs.find(xvar);
sparsity.push_back(deriv != derivs.end() ? 1 : 0);
sparsity.push_back(deriv == derivs.end() || !input_variable(xvar).is_dependent() ? 0 : 1);
if (sparsity.back())
stacklist.push_back(deriv->second);
}
Expand All @@ -342,11 +342,13 @@ VariableStore::collect_output_stack(bool out, bool dout, bool d2out) const
for (const auto & x1var : xvars)
{
const auto & x1derivs = derivs.find(x1var);
if (x1derivs != derivs.end())
if (x1derivs != derivs.end() && input_variable(x1var).is_dependent())
for (const auto & x2var : xvars)
{
const auto & x1x2deriv = x1derivs->second.find(x2var);
sparsity.push_back(x1x2deriv != x1derivs->second.end() ? 1 : 0);
sparsity.push_back(
x1x2deriv == x1derivs->second.end() || !input_variable(x2var).is_dependent() ? 0
: 1);
if (sparsity.back())
stacklist.push_back(x1x2deriv->second);
}
Expand Down

0 comments on commit c4948a0

Please sign in to comment.