Skip to content

Commit

Permalink
Gary jit issue
Browse files Browse the repository at this point in the history
  • Loading branch information
reverendbedford committed Feb 12, 2025
1 parent 8dc160e commit b70f839
Show file tree
Hide file tree
Showing 107 changed files with 430 additions and 1,178 deletions.
39 changes: 12 additions & 27 deletions src/neml2/dispatcher/ValueMapLoader.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,19 @@ broadcast_batch_size(const ValueMap & value_map, Size batch_dim)
{
Size size = 0;
for (auto && [key, tensor] : value_map)
{
auto l_size = tensor.batch_dim() > batch_dim
? std::max(size, tensor.batch_size(batch_dim).concrete())
: 0;
size = std::max(size, l_size);
}
size = std::max(size, tensor.batch_size(batch_dim).concrete());
for (auto && [key, tensor] : value_map)
{
if (tensor.batch_dim() > abs(batch_dim))
{
auto s = tensor.batch_size(batch_dim).concrete();
neml_assert(s == 1 || s == size,
"Batch sizes along batch dimension ",
batch_dim,
" are not compatible. Expected 1 or ",
size,
", got ",
s,
".");
}
auto s = tensor.batch_size(batch_dim).concrete();
neml_assert(s == 1 || s == size,
"Batch sizes along batch dimension ",
batch_dim,
" are not compatible. Expected 1 or ",
size,
", got ",
s,
".");
}
neml_assert(size > 0, "No tensor with batch dimension ", batch_dim, " found.");
return size;
}

Expand All @@ -76,14 +67,8 @@ ValueMapLoader::generate(std::size_t n)

ValueMap work;
for (auto && [key, tensor] : _value_map)
work[key] = (tensor.batch_dim() <= abs(_batch_dim) || tensor.size(_batch_dim) == 1)
? tensor
: tensor.batch_slice(_batch_dim, slice);

std::cout << "SPLIT" << std::endl;
for (auto && [key, tensor] : work)
std::cout << key << " " << tensor.sizes() << std::endl;
work[key] = tensor.size(_batch_dim) == 1 ? tensor : tensor.batch_slice(_batch_dim, slice);

return {m, std::move(work)};
}
} // namespace neml2
} // namespace neml2
6 changes: 0 additions & 6 deletions src/neml2/dispatcher/valuemap_helpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,8 @@ valuemap_cat_reduce(std::vector<ValueMap> && results, Size batch_dim)

// Concatenate the tensors
ValueMap ret;
std::cout << "REJOIN" << std::endl;
for (auto && [name, values] : vars)
{
std::cout << name << std::endl;
for (auto && value : values)
std::cout << value.sizes() << std::endl;
ret[name] = math::batch_cat(values, batch_dim);
}

return ret;
}
Expand Down
30 changes: 24 additions & 6 deletions src/neml2/drivers/TransientDriver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,28 @@ TransientDriver::apply_ic()
set_ic<T>(_result_out[0], input_options(), "ic_" #T "_names", "ic_" #T "_values", _device)
FOR_ALL_TENSORBASE(SET_IC_);

// Figure out what the batch size for our default zero ICs should be
std::vector<Tensor> defined;
for (const auto & var : _model.output_axis().variable_names())
if (_result_out[0].count(var))
defined.push_back(_result_out[0][var]);
for (const auto & [key, value] : _in)
defined.push_back(value);
const auto batch_shape = utils::broadcast_batch_sizes(defined);

// Variables without a user-defined IC are initialized to zeros
for (auto && [name, var] : _model.output_variables())
if (!_result_out[0].count(name))
_result_out[0][name] =
Tensor::zeros(utils::add_shapes(var.list_sizes(), var.base_sizes())).to(_device);
{
if (batch_shape.size() > 0)
_result_out[0][name] = Tensor::zeros(utils::add_shapes(var.list_sizes(), var.base_sizes()))
.to(_device)
.batch_unsqueeze(0)
.batch_expand(batch_shape);
else
_result_out[0][name] =
Tensor::zeros(utils::add_shapes(var.list_sizes(), var.base_sizes())).to(_device);
}
}

void
Expand Down Expand Up @@ -302,14 +319,15 @@ TransientDriver::solve_step()

ValueMapLoader loader(_in, 0);
WorkDispatcher<ValueMap, ValueMap, ValueMap, ValueMap, ValueMap> dispatcher(
[&model = _model](ValueMap && x, torch::Device /*device*/) -> ValueMap
{ return model.value(std::move(x)); },
[&model = _model](ValueMap && x, torch::Device device) -> ValueMap
{
model.to(device);
return model.value(std::move(x));
},
red,
&valuemap_move_device,
post);

std::cout << "IN" << std::endl;
std::cout << _in << std::endl;
_result_out[_step_count] = dispatcher.run(loader, *_scheduler);
}
else
Expand Down
14 changes: 1 addition & 13 deletions tests/dispatcher/test_ValueMapLoader.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ TEST_CASE("ValueMapLoader", "[dispatcher]")
const auto strain = SR2::linspace(strain0, strain1, 100, 1);
const auto temperature_name = VariableName{"forces", "temperature"};
const auto temperature = Scalar::full(300).batch_expand({5, 1, 5});
const auto unexpanded_name = VariableName("state", "not_expanded");
const auto unexpanded = SR2::fill(1.0, 2.0, 3.0);
const auto value_map = ValueMap{
{strain_name, strain}, {temperature_name, temperature}, {unexpanded_name, unexpanded}};
const auto value_map = ValueMap{{strain_name, strain}, {temperature_name, temperature}};

ValueMapLoader loader(value_map, 1);
REQUIRE(loader.total() == 100);
Expand All @@ -59,9 +56,6 @@ TEST_CASE("ValueMapLoader", "[dispatcher]")
REQUIRE(work[temperature_name].batch_sizes() == TensorShape{5, 1, 5});
REQUIRE(work[temperature_name].base_sizes() == TensorShape{});
REQUIRE(torch::allclose(work[temperature_name], temperature.slice(1, 0, 1)));
REQUIRE(work[unexpanded_name].batch_sizes() == TensorShape{});
REQUIRE(work[unexpanded_name].base_sizes() == TensorShape{6});
REQUIRE(torch::allclose(work[unexpanded_name], unexpanded));

REQUIRE(loader.has_more());
std::tie(n, work) = loader.next(2);
Expand All @@ -73,9 +67,6 @@ TEST_CASE("ValueMapLoader", "[dispatcher]")
REQUIRE(work[temperature_name].batch_sizes() == TensorShape{5, 1, 5});
REQUIRE(work[temperature_name].base_sizes() == TensorShape{});
REQUIRE(torch::allclose(work[temperature_name], temperature.slice(1, 0, 1)));
REQUIRE(work[unexpanded_name].batch_sizes() == TensorShape{});
REQUIRE(work[unexpanded_name].base_sizes() == TensorShape{6});
REQUIRE(torch::allclose(work[unexpanded_name], unexpanded));

REQUIRE(loader.has_more());
std::tie(n, work) = loader.next(1000);
Expand All @@ -87,9 +78,6 @@ TEST_CASE("ValueMapLoader", "[dispatcher]")
REQUIRE(work[temperature_name].batch_sizes() == TensorShape{5, 1, 5});
REQUIRE(work[temperature_name].base_sizes() == TensorShape{});
REQUIRE(torch::allclose(work[temperature_name], temperature.slice(1, 0, 1)));
REQUIRE(work[unexpanded_name].batch_sizes() == TensorShape{});
REQUIRE(work[unexpanded_name].base_sizes() == TensorShape{6});
REQUIRE(torch::allclose(work[unexpanded_name], unexpanded));

REQUIRE(!loader.has_more());
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
elastic_stretch plastic_deformation_rate plastic_spin
sum_slip_rates slip_rule slip_strength voce_hardening
integrate_slip_hardening integrate_elastic_strain integrate_orientation"
jit = false
[]
[model]
type = ImplicitUpdate
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
[f0]
type = Scalar
values = '0.01'
batch_shape = '(20)'
[]
[]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
!include model.i

[Schedulers]
[scheduler]
type = SimpleScheduler
device = cpu
batch_size = 8
[]
[]

[Drivers]
[driver]
scheduler = scheduler
[]
[]
Loading

0 comments on commit b70f839

Please sign in to comment.