diff --git a/coaster-nn/src/frameworks/native/helper.rs b/coaster-nn/src/frameworks/native/helper.rs
index 28a4ed20f..1e8d29adc 100644
--- a/coaster-nn/src/frameworks/native/helper.rs
+++ b/coaster-nn/src/frameworks/native/helper.rs
@@ -307,17 +307,32 @@ macro_rules! impl_ops_softmax_for {
             x: &SharedTensor<$t>,
             result: &mut SharedTensor<$t>,
         ) -> Result<(), Error> {
+            // Input tensor must have at least 2 dimensions.
+            // First dimension is treated as a batch number.
+            assert!(
+                x.desc().size() > 1,
+                "Input tensor for softmax must have at least 2 dimensions, got {:?}",
+                x.desc()
+            );
+
+            let batch_size = x.desc()[0];
+            let item_size = x.desc().iter().skip(1).fold(1, |acc, v| acc * v);
+
             let xs = read!(x, $t, self);
             let rs = write_only!(result, $t, self);
             map1(xs, rs, |v| v.exp())?;

-            let mut sum: $t = 0.0; // iter_arith is not stable yet
-            for r in &*rs {
-                sum += *r;
-            }
-            for r in rs {
-                *r /= sum;
+            for i in 0..batch_size {
+                let batch_item = &mut rs[i * item_size..][..item_size];
+
+                let mut sum: $t = 0.0; // iter_arith is not stable yet
+                for r in &*batch_item {
+                    sum += *r;
+                }
+                for r in &mut *batch_item {
+                    *r /= sum;
+                }
             }
             Ok(())
         }
@@ -329,16 +344,32 @@ macro_rules! impl_ops_softmax_for {
             x_diff: &SharedTensor<$t>,
             result_diff: &mut SharedTensor<$t>,
         ) -> Result<(), Error> {
+            let batch_size = x.desc()[0];
+            let item_size = x.desc().iter().skip(1).fold(1, |acc, v| acc * v);
+
             let xs = read!(x, $t, self);
             let dxs = read!(x_diff, $t, self);
             let drs = write_only!(result_diff, $t, self);

-            let mut dot: $t = 0.0;
-            for (t, dt) in xs.iter().zip(dxs.iter()) {
-                dot += t * dt;
+            for i in 0..batch_size {
+                let batch_item_in = &xs[i * item_size..][..item_size];
+                let batch_item_diff_in = &dxs[i * item_size..][..item_size];
+                let batch_item_out = &mut drs[i * item_size..][..item_size];
+
+                let mut dot: $t = 0.0;
+                for (t, dt) in batch_item_in.iter().zip(batch_item_diff_in.iter()) {
+                    dot += t * dt;
+                }
+
+                map2(
+                    batch_item_in,
+                    batch_item_diff_in,
+                    batch_item_out,
+                    |t, dt| t * (dt - dot),
+                )?;
             }
-            map2(xs, dxs, drs, |t, dt| t * (dt - dot))
+
+            Ok(())
         }
     }
 };
@@ -354,20 +385,37 @@ macro_rules! impl_ops_log_softmax_for {
             x: &SharedTensor<$t>,
             result: &mut SharedTensor<$t>,
         ) -> Result<(), $crate::co::error::Error> {
+            // Input tensor must have at least 2 dimensions.
+            // First dimension is treated as a batch number.
+            assert!(
+                x.desc().size() > 1,
+                "Input tensor for softmax must have at least 2 dimensions, got {:?}",
+                x.desc()
+            );
+
+            let batch_size = x.desc()[0];
+            let item_size = x.desc().iter().skip(1).fold(1, |acc, v| acc * v);
+
             let xs = read!(x, $t, self);
             let rs = write_only!(result, $t, self);

-            let max_x = xs
-                .iter()
-                .fold(::std::$t::NEG_INFINITY, |acc, &t| acc.max(t));
+            for i in 0..batch_size {
+                let batch_item_in = &xs[i * item_size..][..item_size];
+                let batch_item_out = &mut rs[i * item_size..][..item_size];
+                let max_x = batch_item_in
+                    .iter()
+                    .fold(::std::$t::NEG_INFINITY, |acc, &t| acc.max(t));

-            let mut logsum: $t = 0.0;
-            for t in xs {
-                logsum += (-(max_x - t)).exp();
+                let mut logsum: $t = 0.0;
+                for t in batch_item_in {
+                    logsum += (*t - max_x).exp();
+                }
+                logsum = max_x + logsum.ln();
+
+                map1(batch_item_in, batch_item_out, |t| t - logsum)?;
             }
-            logsum = max_x + logsum.ln();

-            map1(xs, rs, |t| t - logsum)
+            Ok(())
         }

         fn log_softmax_grad(
@@ -376,15 +424,31 @@ macro_rules! impl_ops_log_softmax_for {
             x_diff: &SharedTensor<$t>,
             result_diff: &mut SharedTensor<$t>,
         ) -> Result<(), $crate::co::error::Error> {
+            let batch_size = x.desc()[0];
+            let item_size = x.desc().iter().skip(1).fold(1, |acc, v| acc * v);
+
             let xs = read!(x, $t, self);
             let dxs = read!(x_diff, $t, self);
             let drs = write_only!(result_diff, $t, self);

-            let mut sum: $t = 0.0;
-            for &grad_val in dxs.iter() {
-                sum += grad_val;
+            for i in 0..batch_size {
+                let batch_item_in = &xs[i * item_size..][..item_size];
+                let batch_item_diff_in = &dxs[i * item_size..][..item_size];
+                let batch_item_out = &mut drs[i * item_size..][..item_size];
+
+                let mut sum: $t = 0.0;
+                for &grad_val in batch_item_diff_in.iter() {
+                    sum += grad_val;
+                }
+                map2(
+                    batch_item_in,
+                    batch_item_diff_in,
+                    batch_item_out,
+                    |t, dt| dt - t.exp() * sum,
+                )?;
             }
-            map2(xs, dxs, drs, |t, dt| dt - t.exp() * sum)
+
+            Ok(())
         }
     }
 };
diff --git a/coaster-nn/src/plugin.rs b/coaster-nn/src/plugin.rs
index bf7f09889..29b6f6d0a 100644
--- a/coaster-nn/src/plugin.rs
+++ b/coaster-nn/src/plugin.rs
@@ -621,7 +621,10 @@ pub trait Convolution: NN {
 /// Provides the functionality for a Backend to support Softmax operations.
 pub trait Softmax: NN {
     /// Computes a [Softmax][softmax] over the input Tensor `x`.
     /// [softmax]: https://en.wikipedia.org/wiki/Softmax_function
+    /// The tensor must have more than one dimension: N, D1, ..., where the first dimension
+    /// N is interpreted as the batch size. The softmax operation is applied independently
+    /// to each batch item over D1, ... .
     ///
     /// Saves the result to `result`.
     fn softmax(
@@ -645,6 +648,9 @@ pub trait Softmax: NN {
 /// Provides the functionality for a Backend to support LogSoftmax operations.
 pub trait LogSoftmax: NN {
     /// Computes a logarithmic softmax over the input Tensor `x`.
+    /// The tensor must have more than one dimension: N, D1, ..., where the first dimension
+    /// N is interpreted as the batch size. The LogSoftmax operation is applied independently
+    /// to each batch item over D1, ... .
     ///
     /// Saves the result to `result`.
     fn log_softmax(
diff --git a/coaster-nn/src/tests/softmax.rs b/coaster-nn/src/tests/softmax.rs
index 7739b8d92..1b3add6c0 100644
--- a/coaster-nn/src/tests/softmax.rs
+++ b/coaster-nn/src/tests/softmax.rs
@@ -6,7 +6,7 @@ use crate::co::prelude::*;
 use crate::plugin::{LogSoftmax, Softmax};
 use crate::tests::{filled_tensor, tensor_assert_eq, tensor_assert_eq_tensor, Epsilon};

-const DIMS: [usize; 3] = [4, 1, 3];
+const DIMS: [usize; 4] = [1, 4, 1, 3];

 const IN: [f64; 12] = [
     -0.3768541784373798341,
diff --git a/juice-examples/mnist-image-multiclass-classification/src/main.rs b/juice-examples/mnist-image-multiclass-classification/src/main.rs
index d306d318c..e06c0683f 100644
--- a/juice-examples/mnist-image-multiclass-classification/src/main.rs
+++ b/juice-examples/mnist-image-multiclass-classification/src/main.rs
@@ -10,16 +10,14 @@ use co::frameworks::cuda::get_cuda_backend;
 #[cfg(not(feature = "cuda"))]
 use co::frameworks::native::get_native_backend;
 use co::prelude::*;
-use juice::layer::*;
-use juice::layers::*;
-use juice::solver::*;
+use juice::net::*;
+use juice::train::*;
 use juice::util::*;
 use juice_utils::{download_datasets, unzip_datasets};
 use mnist::{Mnist, MnistBuilder};
 use serde::Deserialize;

-use std::rc::Rc;
-use std::sync::{Arc, RwLock};
+// TODO: Add a choice for the optimizer (SGD or Adam).

 const MAIN_USAGE: &str = "
 Juice Examples
@@ -133,51 +131,33 @@ fn main() {
 }

 #[cfg(all(feature = "cuda"))]
-fn add_conv_net(
-    mut net_cfg: SequentialConfig,
-    batch_size: usize,
-    pixel_dim: usize,
-) -> SequentialConfig {
-    net_cfg.add_layer(LayerConfig::new(
-        "reshape",
-        ReshapeConfig::of_shape(&[batch_size, 1, pixel_dim, pixel_dim]),
-    ));
-    net_cfg.add_layer(LayerConfig::new(
+fn add_conv_net(mut net_cfg: SequentialConfig) -> SequentialConfig {
+    net_cfg.add_layer(
         "conv",
         ConvolutionConfig {
-            num_output: 20,
-            filter_shape: vec![5],
-            padding: vec![0],
-            stride: vec![1],
+            feature_maps: 20,
+            kernel_size: 5,
+            padding: 0,
+            stride: 1,
         },
-    ));
-    net_cfg.add_layer(LayerConfig::new(
+    );
+    net_cfg.add_layer(
         "pooling",
         PoolingConfig {
             mode: PoolingMode::Max,
-            filter_shape: vec![2],
-            padding: vec![0],
-            stride: vec![2],
+            window_size: 2,
+            padding: 0,
+            stride: 2,
         },
-    ));
-    net_cfg.add_layer(LayerConfig::new(
-        "linear1",
-        LinearConfig { output_size: 500 },
-    ));
-    net_cfg.add_layer(LayerConfig::new("sigmoid", LayerType::Sigmoid));
-    net_cfg.add_layer(LayerConfig::new(
-        "linear2",
-        LinearConfig { output_size: 10 },
-    ));
+    );
+    net_cfg.add_layer("linear1", LinearConfig { output_size: 500 });
+    net_cfg.add_layer("sigmoid", LayerConfig::Sigmoid);
+    net_cfg.add_layer("linear2", LinearConfig { output_size: 10 });
     net_cfg
 }

 #[cfg(not(feature = "cuda"))]
-fn add_conv_net(
-    _net_cfg: SequentialConfig,
-    _batch_size: usize,
-    _pixel_dim: usize,
-) -> SequentialConfig {
+fn add_conv_net(_net_cfg: SequentialConfig) -> SequentialConfig {
     println!(
         "Currently Juice does not have a native pooling function to use with Conv Nets
 - you can either try the CUDA implementation, or use a different type of layer"
@@ -185,32 +165,15 @@ fn add_conv_net(
     panic!()
 }

-fn add_mlp(
-    mut net_cfg: SequentialConfig,
-    batch_size: usize,
-    pixel_count: usize,
-) -> SequentialConfig {
-    net_cfg.add_layer(LayerConfig::new(
-        "reshape",
-        LayerType::Reshape(ReshapeConfig::of_shape(&[batch_size, pixel_count])),
-    ));
-    net_cfg.add_layer(LayerConfig::new(
-        "linear1",
-        LayerType::Linear(LinearConfig { output_size: 1568 }),
-    ));
-    net_cfg.add_layer(LayerConfig::new("sigmoid", LayerType::Sigmoid));
-    net_cfg.add_layer(LayerConfig::new(
-        "linear2",
-        LayerType::Linear(LinearConfig { output_size: 10 }),
-    ));
+fn add_mlp(mut net_cfg: SequentialConfig) -> SequentialConfig {
+    net_cfg.add_layer("linear1", LinearConfig { output_size: 1568 });
+    net_cfg.add_layer("sigmoid", LayerConfig::Sigmoid);
+    net_cfg.add_layer("linear2", LinearConfig { output_size: 10 });
     net_cfg
 }

 fn add_linear_net(mut net_cfg: SequentialConfig) -> SequentialConfig {
-    net_cfg.add_layer(LayerConfig::new(
-        "linear",
-        LayerType::Linear(LinearConfig { output_size: 10 }),
-    ));
+    net_cfg.add_layer("linear", LinearConfig { output_size: 10 });
     net_cfg
 }

@@ -247,73 +210,57 @@ fn run_mnist(
     let batch_size = batch_size.unwrap_or(30);
     let learning_rate = learning_rate.unwrap_or(0.001f32);
-    let momentum = momentum.unwrap_or(0f32);
+    let momentum = momentum.unwrap_or(0.1f32);

-    let mut net_cfg = SequentialConfig::default();
-    net_cfg.add_input("data", &[batch_size, pixel_dim, pixel_dim]);
-    net_cfg.force_backward = true;
+    // Create the backend.
+    #[cfg(all(feature = "cuda"))]
+    let backend = get_cuda_backend();
+    #[cfg(not(feature = "cuda"))]
+    let backend = get_native_backend();
+
+    // Create the network configuration and the net itself.
+    let mut net_cfg = SequentialConfig::default();
     net_cfg = match &*model_name.unwrap_or("none".to_owned()) {
-        "conv" => add_conv_net(net_cfg, batch_size, pixel_dim),
-        "mlp" => add_mlp(net_cfg, batch_size, pixel_count),
+        "conv" => add_conv_net(net_cfg),
+        "mlp" => add_mlp(net_cfg),
         "linear" => add_linear_net(net_cfg),
         _ => panic!("Unknown model. Try one of [linear, mlp, conv]"),
     };
-
-    net_cfg.add_layer(LayerConfig::new("log_softmax", LayerType::LogSoftmax));
-
-    let mut classifier_cfg = SequentialConfig::default();
-    classifier_cfg.add_input("network_out", &[batch_size, 10]);
-    classifier_cfg.add_input("label", &[batch_size, 1]);
-    // set up nll loss
-    let nll_layer_cfg = NegativeLogLikelihoodConfig { num_classes: 10 };
-    let nll_cfg = LayerConfig::new("nll", LayerType::NegativeLogLikelihood(nll_layer_cfg));
-    classifier_cfg.add_layer(nll_cfg);
-
-    // set up backends
-    #[cfg(all(feature = "cuda"))]
-    let backend = Rc::new(get_cuda_backend());
-    #[cfg(not(feature = "cuda"))]
-    let backend = Rc::new(get_native_backend());
-
-    // set up solver
-    let mut solver_cfg = SolverConfig {
-        minibatch_size: batch_size,
-        base_lr: learning_rate,
-        momentum,
-        ..SolverConfig::default()
+    net_cfg.add_layer("log_softmax", LayerConfig::LogSoftmax);
+    let mut net =
+        Network::from_config(&backend, net_cfg, &[vec![1, pixel_dim, pixel_dim]]).unwrap();
+
+    // Create the trainer.
+    let trainer_config = TrainerConfig {
+        batch_size,
+        objective: LayerConfig::NegativeLogLikelihood,
+        optimizer: OptimizerConfig::SgdWithMomentum(SgdWithMomentumConfig { momentum }),
+        learning_rate,
+        ..Default::default()
     };
-    solver_cfg.network = LayerConfig::new("network", net_cfg);
-    solver_cfg.objective = LayerConfig::new("classifier", classifier_cfg);
-    let mut solver = Solver::from_config(backend.clone(), backend.clone(), &solver_cfg);
+    let mut trainer = Trainer::from_config(&backend, trainer_config, &net, &vec![1]);

-    // set up confusion matrix
+    // Set up confusion matrix.
     let mut classification_evaluator = ::juice::solver::ConfusionMatrix::new(10);
     classification_evaluator.set_capacity(Some(1000));

-    let input = SharedTensor::<f32>::new(&[batch_size, pixel_dim, pixel_dim]);
-    let inp_lock = Arc::new(RwLock::new(input));
-
-    let label = SharedTensor::<f32>::new(&[batch_size, 1]);
-    let label_lock = Arc::new(RwLock::new(label));
+    let mut input = SharedTensor::<f32>::new(&[batch_size, pixel_dim, pixel_dim]);
+    let mut label = SharedTensor::<f32>::new(&[batch_size, 1]);

     for _ in 0..(example_count / batch_size as u32) {
         // write input
         let mut targets = Vec::new();

-        for (batch_n, (label_val, ref input)) in
+        for (batch_n, (label_val, ref input_bytes)) in
             decoded_images.by_ref().take(batch_size).enumerate()
         {
-            let mut input_tensor = inp_lock.write().unwrap();
-            let mut label_tensor = label_lock.write().unwrap();
-            write_batch_sample(&mut input_tensor, &input, batch_n);
-            write_batch_sample(&mut label_tensor, &[label_val], batch_n);
+            write_batch_sample(&mut input, &input_bytes, batch_n);
+            write_batch_sample(&mut label, &[label_val], batch_n);
             targets.push(label_val as usize);
         }
         // train the network!
-        let infered_out = solver.train_minibatch(inp_lock.clone(), label_lock.clone());
+        let mut infered = trainer.train_minibatch(&backend, &mut net, &input, &label);

-        let mut infered = infered_out.write().unwrap();
         let predictions = classification_evaluator.get_predictions(&mut infered);
         classification_evaluator.add_samples(&predictions, &targets);
diff --git a/juice/src/net/common/log_softmax.rs b/juice/src/net/common/log_softmax.rs
new file mode 100644
index 000000000..a8e12c0c8
--- /dev/null
+++ b/juice/src/net/common/log_softmax.rs
@@ -0,0 +1,82 @@
+use crate::co::IBackend;
+use crate::conn;
+use crate::net::{Context, Descriptor, Layer};
+
+#[derive(Debug, Clone)]
+pub struct LogSoftmax {
+    descriptor: Descriptor,
+}
+
+impl LogSoftmax {
+    pub fn new(mut descriptor: Descriptor) -> Self {
+        assert_eq!(descriptor.inputs().len(), 1); // Should only be one input.
+
+        descriptor.add_output(descriptor.input(0).unit_shape().clone());
+
+        LogSoftmax { descriptor: descriptor }
+    }
+}
+
+impl<B: IBackend + conn::LogSoftmax<f32>> Layer<B> for LogSoftmax {
+    fn compute_output(&self, backend: &B, context: &mut Context) {
+        let input = context.get_data(self.descriptor.input(0));
+        let output = context.acquire_data(self.descriptor.output(0));
+        backend.log_softmax(&input.borrow(), &mut output.borrow_mut()).unwrap();
+    }
+
+    fn compute_gradients(&self, backend: &B, context: &mut Context) {
+        let input = context.get_data(self.descriptor.input(0));
+        let output = context.get_data(self.descriptor.output(0));
+        let output_gradient = context.get_data_gradient(self.descriptor.output(0));
+        let input_gradient = context.acquire_data_gradient(self.descriptor.input(0));
+        backend
+            .log_softmax_grad(
+                &output.borrow(),
+                &output_gradient.borrow(),
+                &mut input_gradient.borrow_mut(),
+            )
+            .unwrap();
+    }
+
+    fn descriptor(&self) -> &Descriptor {
+        &self.descriptor
+    }
+
+    fn descriptor_mut(&mut self) -> &mut Descriptor {
+        &mut self.descriptor
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use coaster::frameworks::native::get_native_backend;
+
+    use crate::net::{layer::testing::*, LayerConfig, Network};
+
+    #[test]
+    fn compute() {
+        let backend = get_native_backend();
+        let net = Network::from_config(&backend, LayerConfig::LogSoftmax, &[vec![2]]).unwrap();
+        let result = get_net_output(&backend, &net, &create_tensor_2d([[1.0, -2.0], [3.5, 4.0]]));
+        assert_tensor_eq(
+            &result.output,
+            &create_tensor_2d([[-0.04859, -3.04859], [-0.97408, -0.47408]]),
+        );
+    }
+
+    #[test]
+    fn compute_gradients() {
+        let backend = get_native_backend();
+        let net = Network::from_config(&backend, LayerConfig::LogSoftmax, &[vec![2]]).unwrap();
+        let result = get_net_output_and_gradients(
+            &backend,
+            &net,
+            &create_tensor_2d([[1.0, -2.0], [-3.0, 4.0]]),
+            &create_tensor_2d([[0.4, 0.3], [0.1, 0.2]]),
+        );
+        assert_tensor_eq(
+            &result.input_gradient,
+            &create_tensor_2d([[-0.26680, 0.26680], [0.09973, -0.09973]]),
+        );
+    }
+}
diff --git a/juice/src/net/common/mod.rs b/juice/src/net/common/mod.rs
index c6c497ec9..ac0518aeb 100644
--- a/juice/src/net/common/mod.rs
+++ b/juice/src/net/common/mod.rs
@@ -1,9 +1,13 @@
 mod convolution;
 mod dropout;
 mod linear;
+mod log_softmax;
 mod pooling;
+mod softmax;

 pub use convolution::*;
 pub use dropout::*;
 pub use linear::*;
-pub use pooling::*;
\ No newline at end of file
+pub use log_softmax::*;
+pub use pooling::*;
+pub use softmax::*;
\ No newline at end of file
diff --git a/juice/src/net/common/softmax.rs b/juice/src/net/common/softmax.rs
new file mode 100644
index 000000000..f2f8f0c48
--- /dev/null
+++ b/juice/src/net/common/softmax.rs
@@ -0,0 +1,90 @@
+use coaster::ITensorDesc;
+
+use crate::co::IBackend;
+use crate::conn;
+use crate::net::{Context, Descriptor, Layer};
+
+#[derive(Debug, Clone)]
+pub struct Softmax {
+    descriptor: Descriptor,
+}
+
+impl Softmax {
+    pub fn new(mut descriptor: Descriptor) -> Self {
+        assert_eq!(descriptor.inputs().len(), 1); // Should only be one input.
+
+        descriptor.add_output(descriptor.input(0).unit_shape().clone());
+
+        Softmax { descriptor: descriptor }
+    }
+}
+
+impl<B: IBackend + conn::Softmax<f32>> Layer<B> for Softmax {
+    fn compute_output(&self, backend: &B, context: &mut Context) {
+        let input = context.get_data(self.descriptor.input(0));
+        let output = context.acquire_data(self.descriptor.output(0));
+
+        // Since the backend expects the first dimension to be the batch number,
+        // verify that this assumption holds.
+        let batch_item_size = input.borrow().desc().iter().skip(1).fold(1, |acc, v| acc * v);
+        assert_eq!(batch_item_size, self.descriptor.input(0).unit_shape().size());
+
+        backend.softmax(&input.borrow(), &mut output.borrow_mut()).unwrap();
+    }
+
+    fn compute_gradients(&self, backend: &B, context: &mut Context) {
+        let input = context.get_data(self.descriptor.input(0));
+        let output = context.get_data(self.descriptor.output(0));
+        let output_gradient = context.get_data_gradient(self.descriptor.output(0));
+        let input_gradient = context.acquire_data_gradient(self.descriptor.input(0));
+        backend
+            .softmax_grad(
+                &output.borrow(),
+                &output_gradient.borrow(),
+                &mut input_gradient.borrow_mut(),
+            )
+            .unwrap();
+    }
+
+    fn descriptor(&self) -> &Descriptor {
+        &self.descriptor
+    }
+
+    fn descriptor_mut(&mut self) -> &mut Descriptor {
+        &mut self.descriptor
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use coaster::frameworks::native::get_native_backend;
+
+    use crate::net::{layer::testing::*, LayerConfig, Network};
+
+    #[test]
+    fn compute() {
+        let backend = get_native_backend();
+        let net = Network::from_config(&backend, LayerConfig::Softmax, &[vec![2]]).unwrap();
+        let result = get_net_output(&backend, &net, &create_tensor_2d([[1.0, -2.0], [3.5, 4.0]]));
+        assert_tensor_eq(
+            &result.output,
+            &create_tensor_2d([[0.95257, 0.04743], [0.37754, 0.62246]]),
+        );
+    }
+
+    #[test]
+    fn compute_gradients() {
+        let backend = get_native_backend();
+        let net = Network::from_config(&backend, LayerConfig::Softmax, &[vec![2]]).unwrap();
+        let result = get_net_output_and_gradients(
+            &backend,
+            &net,
+            &create_tensor_2d([[1.0, -2.0], [3.5, 4.0]]),
+            &create_tensor_2d([[0.4, 0.3], [0.1, 0.2]]),
+        );
+        assert_tensor_eq(
+            &result.input_gradient,
+            &create_tensor_2d([[0.00452, -0.00452], [-0.02350, 0.02350]]),
+        );
+    }
+}
diff --git a/juice/src/net/config.rs b/juice/src/net/config.rs
index 0409ea347..58b05172b 100644
--- a/juice/src/net/config.rs
+++ b/juice/src/net/config.rs
@@ -1,5 +1,5 @@
 use super::{
-    ConvolutionConfig, DropoutConfig, LinearConfig, NegativeLogLikelihoodConfig, PoolingConfig, SequentialConfig,
+    ConvolutionConfig, DropoutConfig, LinearConfig, PoolingConfig, SequentialConfig,
 };

 /// A configuration of the layer.
@@ -9,12 +9,14 @@ pub enum LayerConfig {
     Convolution(ConvolutionConfig),
     Dropout(DropoutConfig),
     Linear(LinearConfig),
+    LogSoftmax,
     MeanSquaredError,
-    NegativeLogLikelihood(NegativeLogLikelihoodConfig),
+    NegativeLogLikelihood,
     Pooling(PoolingConfig),
     Relu,
     Sequential(SequentialConfig),
     Sigmoid,
+    Softmax,
     // TODO: Add other layer configs.
 }
diff --git a/juice/src/net/layer.rs b/juice/src/net/layer.rs
index 59c804b22..e0fbe38d8 100644
--- a/juice/src/net/layer.rs
+++ b/juice/src/net/layer.rs
@@ -65,12 +65,14 @@ pub fn layer_from_config<B: IBackend + LayerOps<f32> + 'static>(
         LayerConfig::Convolution(cfg) => Box::new(Convolution::new(descriptor, cfg)),
         LayerConfig::Dropout(cfg) => Box::new(Dropout::new(backend, descriptor, cfg)),
         LayerConfig::Linear(cfg) => Box::new(Linear::new(descriptor, cfg)),
+        LayerConfig::LogSoftmax => Box::new(LogSoftmax::new(descriptor)),
         LayerConfig::MeanSquaredError => Box::new(MeanSquaredError::new(descriptor)),
-        LayerConfig::NegativeLogLikelihood(cfg) => Box::new(NegativeLogLikelihood::new(descriptor, cfg)),
+        LayerConfig::NegativeLogLikelihood => Box::new(NegativeLogLikelihood::new(descriptor)),
        LayerConfig::Pooling(cfg) => Box::new(Pooling::new(backend, descriptor, cfg)),
         LayerConfig::Relu => Box::new(Relu::new(descriptor)),
         LayerConfig::Sequential(cfg) => Box::new(Sequential::new(backend, descriptor, cfg)?),
         LayerConfig::Sigmoid => Box::new(Sigmoid::new(descriptor)),
+        LayerConfig::Softmax => Box::new(Softmax::new(descriptor)),
     })
 }
diff --git a/juice/src/net/loss/negative_log_likelihood.rs b/juice/src/net/loss/negative_log_likelihood.rs
index 19c7ec975..915b36f61 100644
--- a/juice/src/net/loss/negative_log_likelihood.rs
+++ b/juice/src/net/loss/negative_log_likelihood.rs
@@ -2,20 +2,13 @@ use crate::co::{IBackend, ITensorDesc};
 use crate::net::{Context, Descriptor, Layer};
 use crate::util::native_backend;

-#[derive(Clone, Debug, Default, PartialEq)]
-pub struct NegativeLogLikelihoodConfig {
-    /// How many different classes can be classified.
-    pub num_classes: usize,
-}
-
 #[derive(Debug)]
 pub struct NegativeLogLikelihood {
     descriptor: Descriptor,
-    num_classes: usize,
 }

 impl NegativeLogLikelihood {
-    pub fn new(descriptor: Descriptor, config: &NegativeLogLikelihoodConfig) -> Self {
+    pub fn new(descriptor: Descriptor) -> Self {
         assert_eq!(
             descriptor.inputs().len(),
             2,
@@ -26,15 +19,12 @@ impl NegativeLogLikelihood {
             1,
             "Labels must be of [1] shape"
         );
-
+
         // Note that loss layers don't have outputs, since the result of loss computation is always
         // a single number which can't then be piped to other layers which expect data to have
         // shape [batch_size, ...]
-        NegativeLogLikelihood {
-            descriptor: descriptor,
-            num_classes: config.num_classes,
-        }
+        NegativeLogLikelihood { descriptor }
     }
 }

@@ -55,15 +45,17 @@ impl<B: IBackend> Layer<B> for NegativeLogLikelihood {
         let native_labels = labels_data.read(native.device()).unwrap().as_slice::<f32>();

         let mut writable_gradient = vec![0f32; probabilities_gradient.borrow().desc().size()];
+        let num_classes = self.descriptor.input(0).unit_shape().size();
+
         for (batch_n, &label_value) in native_labels.iter().enumerate() {
-            let index = (self.num_classes * batch_n) + label_value as usize;
+            let label_index = label_value as usize;
+            assert!(label_index < num_classes, "Wrong label {} exceeding input size {}", label_index, num_classes);
+
+            let index = (num_classes * batch_n) + label_index;
             writable_gradient[index] = -1f32;
         }
         crate::util::write_to_memory(
-            probabilities_gradient
-                .borrow_mut()
-                .write_only(native.device())
-                .unwrap(),
+            probabilities_gradient.borrow_mut().write_only(native.device()).unwrap(),
             &writable_gradient,
         );
     }
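
For reference, a minimal standalone sketch (not part of the patch) of the per-batch-item semantics the backend change above implements: softmax is computed independently over each item of a row-major [batch_size, item_size] buffer, mirroring how the native helper slices its output into `batch_item` chunks. The function name `softmax_batched` and the max-subtraction for numerical stability are illustrative assumptions here; the patch's plain softmax path exponentiates directly, and only its log_softmax path subtracts the per-item maximum.

// Illustrative sketch only; not code from this repository.
fn softmax_batched(data: &mut [f32], batch_size: usize, item_size: usize) {
    assert_eq!(data.len(), batch_size * item_size);
    for i in 0..batch_size {
        // Each batch item is normalized independently.
        let item = &mut data[i * item_size..][..item_size];
        // Subtract the per-item max before exponentiating (stability assumption).
        let max = item.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for v in item.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in item.iter_mut() {
            *v /= sum;
        }
    }
}

fn main() {
    // Two batch items of three values each; each row sums to 1 on its own.
    let mut x = vec![1.0_f32, 2.0, 3.0, 0.0, 0.0, 10.0];
    softmax_batched(&mut x, 2, 3);
    println!("{:?}", &x[..3]); // ~[0.090, 0.245, 0.665]
    println!("{:?}", &x[3..]); // ~[0.0000454, 0.0000454, 0.99991]
}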