Hi @rachtibat,
The `run_distributed` method of the `FeatureVisualization` class does not take the actual `batch_size` into account in the multi-target case.
Maybe include something like:
```python
if n_samples > batch_size:
    batches_ = math.ceil(len(conditions) / batch_size)
else:
    batches_ = 1

for b_ in range(batches_):
    data_broadcast_ = data_broadcast[b_ * batch_size: (b_ + 1) * batch_size]
    # print(len(conditions), len(data_broadcast_))
    conditions_ = conditions[b_ * batch_size: (b_ + 1) * batch_size]

    # dict_inputs is linked to FeatHooks
    dict_inputs["sample_indices"] = sample_indices[b_ * batch_size: (b_ + 1) * batch_size]
    dict_inputs["targets"] = targets[b_ * batch_size: (b_ + 1) * batch_size]

    # composites are already registered before
    self.attribution(data_broadcast_, conditions_, None, exclude_parallel=False)
```
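The slicing idea in the snippet above can be checked in isolation. Below is a minimal, framework-free sketch, assuming that the conditions, sample indices and targets are parallel sequences of equal length (one entry per sample/target pair). `split_into_batches`, the example values and the `{"y": [...]}` condition shape are hypothetical and used only for illustration; none of them is taken from zennit-crp itself.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_batches(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    """Split `items` into consecutive chunks of at most `batch_size` elements.

    Hypothetical helper for illustration; not part of zennit-crp.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_batches = math.ceil(len(items) / batch_size)
    return [items[b * batch_size:(b + 1) * batch_size] for b in range(n_batches)]


# Hypothetical multi-target example: 3 samples with 2 targets each give
# 6 (sample, target) pairs, i.e. 6 conditions, although there are only 3 samples.
sample_indices = [0, 0, 1, 1, 2, 2]
targets = [4, 7, 4, 9, 1, 7]
conditions = [{"y": [t]} for t in targets]  # condition format assumed for illustration

batch_size = 4
for conds, idx, tgt in zip(
    split_into_batches(conditions, batch_size),
    split_into_batches(sample_indices, batch_size),
    split_into_batches(targets, batch_size),
):
    # In run_distributed, this is where the sliced data_broadcast_, conditions_
    # and FeatHooks inputs would be passed on to the attribution call.
    print(len(conds), idx, tgt)
```

This sketch derives the number of batches from the number of conditions alone, so every condition lands in exactly one batch of at most `batch_size` entries, independent of how many distinct samples there are.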
This would fix a GPU memory issue I am running into.
Best, Max