diff --git a/Cargo.toml b/Cargo.toml index 487030a..671e792 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,12 @@ slog = {version = "1.2", features = ["max_level_trace"] } slog-term = "1.3.5" tera = "0.6.2" ocl = "0.12.0" +arrayfire = "3.4.1" + + +[replace] +"ocl-core:0.3.2" = { git = "https://github.com/cogciprocate/ocl-core" } +"ocl:0.12.0" = { git = "https://github.com/cogciprocate/ocl/"} [lib] name = "gir" diff --git a/src/api/exprs/special.rs b/src/api/exprs/special.rs index 1f1b9cb..96ae050 100644 --- a/src/api/exprs/special.rs +++ b/src/api/exprs/special.rs @@ -5,7 +5,7 @@ use ops::interface::default::*; use super::super::ids; use std::ops::DerefMut; -pub fn overwrite_update>(arg: T, upd: T) -> Result<()> { +pub fn overwrite_update>(arg: T, upd: T) -> Result { let arg = arg.as_ref(); let upd = upd.as_ref(); same_graph_2(arg, upd)?; @@ -16,7 +16,8 @@ pub fn update>(arg: T, upd: T) -> Result<()> { let arg = arg.as_ref(); let upd = upd.as_ref(); same_graph_2(arg, upd)?; - ids::update(arg.wrapper.get_mut().deref_mut(), arg.id, upd.id) + ids::update(arg.wrapper.get_mut().deref_mut(), arg.id, upd.id)?; + Ok(()) } pub fn cast>(arg: T, data_type: FundamentalType) -> Result { diff --git a/src/api/ids/linalg.rs b/src/api/ids/linalg.rs index 5674456..5a5415a 100644 --- a/src/api/ids/linalg.rs +++ b/src/api/ids/linalg.rs @@ -3,5 +3,5 @@ use graph::*; use errors::*; pub fn mat_mul(graph: &mut Graph, arg0: usize, arg1: usize) -> Result { - Ok(graph.apply_op(Box::new(MatrixMul {}), vec![arg0, arg1])?) + Ok(graph.apply_op(Box::new(MatMul {}), vec![arg0, arg1])?) } diff --git a/src/api/ids/special.rs b/src/api/ids/special.rs index 8eb2ec0..724def7 100644 --- a/src/api/ids/special.rs +++ b/src/api/ids/special.rs @@ -4,9 +4,10 @@ use graph::*; use errors::*; use api::ids; -pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> { - graph.updates.remove(&arg); - update(graph, arg, upd) +pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result { + let existed = remove_update(graph, arg)?; + update(graph, arg, upd)?; + Ok(existed) // let candidates = graph.get_node(arg)?.children.clone(); // let mut old_update = None; // for &c in &candidates { @@ -36,7 +37,7 @@ pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> // } } -pub fn update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> { +pub fn update(graph: &mut Graph, arg:usize, upd: usize) -> Result { // Verify first argument is a Parameter if graph.get_node(arg)?.op.get_meta().name != "Parameter" { return Err(ErrorKind::InvalidArguments( @@ -44,14 +45,30 @@ pub fn update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> { "First argument must be a parameter.".into()).into()) } // Verify that the first argument does not already have an Update - if let Some(u) = graph.updates.get(&arg) { - let ref param_name = graph.nodes[arg].name; - return Err(ErrorKind::InvalidArguments( - "Update".into(), vec![arg, upd], - format!("The parameter '{}' already has an update - {}.", param_name, u)).into()) + for &u in graph.op_map.get("Update").unwrap() { + if graph.nodes[u].ancestors[0] == arg { + let ref param_name = graph.nodes[arg].name; + return Err(ErrorKind::InvalidArguments( + "Update".into(), vec![arg, upd], + format!("The parameter '{}' already has an update - {}.", param_name, u)).into()) + } + } + graph.apply_op(Box::new(Update {}), vec![arg, upd]) +} + +pub fn remove_update(graph: &mut Graph, arg:usize) -> Result { + let updates = graph.op_map.get("Update").unwrap(); + for &u in updates { + if graph.nodes[u].ancestors[0] == arg { + let upd = graph.nodes[u].ancestors[1]; + graph.nodes[arg].children.remove(&u); + graph.nodes[upd].children.remove(&u); + let op = Box::new(Cleared{}); + graph.nodes[u] = op.apply_null(); + return Ok(true) + } } - graph.updates.insert(arg, upd); - Ok(()) + Ok(false) } pub fn cast(graph: &mut Graph, arg: usize, data_type: FundamentalType) -> Result { diff --git a/src/backend/af/backend.rs b/src/backend/af/backend.rs new file mode 100644 index 0000000..569790b --- /dev/null +++ b/src/backend/af/backend.rs @@ -0,0 +1,98 @@ +use primitives::*; +use graph::*; +use backend::*; + +use std::collections::HashMap; +use std::rc::Rc; +use std::cell::{Ref, RefCell}; +use std::io; +use af::function::AfFunction; +use arrayfire as af; +use arrayfire::Array; + + +/// For now this will support only single device +#[derive(Clone)] +pub struct AfBackend { + pub platform: ::arrayfire::Backend, + pub device: i32, + pub parameters: Rc>>, + pub precisions: BackendPrecisions +} + + +impl Default for AfBackend { + fn default() -> Self { + // Todo similar to GraphProps this should be loaded from system file + AfBackend { + platform: ::arrayfire::Backend::DEFAULT, + device: 0, + parameters: Rc::new(RefCell::new(HashMap::new())), + precisions: BackendPrecisions::default() + } + } +} + +impl AfBackend { + pub fn get_param_value(&self, name: &str) -> Ref { + Ref::map(self.parameters.borrow(), |x| x.get(name).unwrap()) + } + + pub fn set_param_value(&mut self, name: &str, value:Array) -> Result<(), String> { + if let Some(v) = self.parameters.borrow().get(name) { + if v.dims() != value.dims() { + return Err(format!("The parameter {} has shape {}, \ + but {} was passed to set_param_value.", name, v.dims(), value.dims())) + } + } + self.parameters.borrow_mut().insert(name.into(), value); + Ok(()) + } +} + +impl Backend for AfBackend { + fn make_function(&self, gf: GraphFunction) + -> AfFunction { + let sym_input_shapes = gf.inputs.iter() + .map(|&id| gf.graph.nodes[id].shape.clone()).collect(); + AfFunction { + initialized: false, + precisions: self.precisions, + gf: gf, + parameters: self.parameters.clone(), + sym_input_shapes: sym_input_shapes, + last_shapes: Vec::new(), + last_deduced: HashMap::new(), + expr_map: HashMap::new() + } + } + + fn get_precisions(&self) -> &BackendPrecisions { + &self.precisions + } + fn set_precisions(&mut self, precisions: BackendPrecisions){ + self.precisions = precisions; + } + fn info(&self, f:&mut io::Write) -> io::Result<()> { + writeln!(f, "Platform: {}", self.platform)?; + writeln!(f, "\tDevices: {}", af::device_count()) + } + + fn general_info(&self, f: &mut io::Write) -> io::Result<()> { + let backend = af::get_active_backend(); + writeln!(f, "Arrayfire Backend General Information:")?; + writeln!(f, "==================================================")?; + for b in af::get_available_backends() { + writeln!(f, "Platform: {}", b)?; + writeln!(f, "\tDevices: {}", af::device_count())?; + af::set_backend(b); + af::info(); + } + af::set_backend(backend); + Ok(()) + } + + fn print_info(&self) -> io::Result<()> { + Ok(::arrayfire::info()) + } +} diff --git a/src/backend/af/function.rs b/src/backend/af/function.rs new file mode 100644 index 0000000..becd1ba --- /dev/null +++ b/src/backend/af/function.rs @@ -0,0 +1,209 @@ +use primitives::*; +use graph::*; +use backend::*; +use errors::*; + +use std::collections::HashMap; +use std::rc::Rc; +use std::cell::RefCell; +use std::ops::Neg; + +use arrayfire as af; +use arrayfire::print_gen; + +#[derive(Clone)] +pub struct AfFunction { + pub initialized: bool, + pub precisions: BackendPrecisions, + pub gf: GraphFunction, + pub parameters: Rc>>, + pub sym_input_shapes: Vec, + pub last_shapes: Vec<[usize; 4]>, + pub last_deduced: HashMap, + pub expr_map: HashMap, +} + +impl AfFunction { + pub fn internal_eval(&mut self, inputs: &[&af::Array]) { + for (&id, input) in self.gf.inputs.iter().zip(inputs) { + self.expr_map.insert(id, (*input).clone()); + } + for (name, &id) in self.gf.parameters.iter() { + self.expr_map.insert(id, self.parameters.borrow().get(name).unwrap().clone()); +// let v = self.expr_map.get(&id).unwrap(); +// println!("Id: {}", id); +// af_print!("Value:",v); + } + let order = self.gf.graph.order.clone(); + for &id in &order { + self.compute_node(id); + } + } +} + +impl CompiledFunction for AfFunction { + fn eval(&mut self, inputs: &[&af::Array]) -> Result> { + // Check correct number of inputs are provided + if inputs.len() != self.gf.inputs.len() { + return Err(ErrorKind::Msg(format!("Incorrect number of inputs. \ + Expected: {}, actual: {}.", self.gf.inputs.len(), inputs.len())).into()); + } + let input_shapes: Vec<[usize;4]> = inputs.iter().map(|x| { + let mut dims = [1, 1, 1, 1]; + for (i, &d) in x.dims().get().iter().enumerate() { + dims[i] = d as usize; + } + dims + }).collect(); + // Check shapes are correct and if they have changed + match verify_shapes(&input_shapes, &self.last_shapes, &self.sym_input_shapes)? { + Some(deduced) => { + self.last_shapes = input_shapes; + self.last_deduced = deduced; + }, + None => {} + } + self.internal_eval(inputs); + let mut result = Vec::new(); + for i in &self.gf.outputs { + result.push(self.expr_map.remove(i).unwrap()); + } +// let output = self.gf.outputs.iter().map(|x| ).collect(); + Ok(result) + } + + fn initialized(&self) -> bool { + self.initialized + } + fn free_memory(&mut self) {} +} + +impl AfFunction { + fn compute_node(&mut self, id: usize) { + let ref node = self.gf.graph.nodes[id]; + let expr_map = &mut self.expr_map; + let op_meta = node.op.get_meta(); + match op_meta.name { + "Input" | "Parameter" => {}, + "Scalar" => { + let (value, _) = *node.op.get_args().unwrap() + .downcast::<(f64, FundamentalType)>().unwrap(); + let result = af::constant(value as f32, af::Dim4::new(&[1, 1, 1, 1])); + expr_map.insert(node.id, result); + }, + "Add" => { + let result = match node.ancestors.len() { + 2 => af::add(expr_map.get(&node.ancestors[0]).unwrap(), + expr_map.get(&node.ancestors[1]).unwrap(), true), + _ => unimplemented!() + }; + expr_map.insert(node.id, result); + }, + "Mul" => { + let result = match node.ancestors.len() { + 2 => af::mul(expr_map.get(&node.ancestors[0]).unwrap(), + expr_map.get(&node.ancestors[1]).unwrap(), true), + _ => unimplemented!() + }; + expr_map.insert(node.id, result); + }, + "MatMul" => { + // println!("{:?} vs {:?}", expr_map.get(&node.ancestors[0]).unwrap().get_type(), + // expr_map.get(&node.ancestors[1]).unwrap().get_type()); + // println!("{:?} vs {:?}", expr_map.get(&node.ancestors[0]).unwrap().dims(), + // expr_map.get(&node.ancestors[1]).unwrap().dims()); + // println!("{:?}", node.ancestors); + let result = match node.ancestors.len() { + 2 => af::matmul(expr_map.get(&node.ancestors[0]).unwrap(), + expr_map.get(&node.ancestors[1]).unwrap(), + af::MatProp::NONE, af::MatProp::NONE), + _ => unimplemented!() + }; + expr_map.insert(node.id, result); + }, + "Reorder" => { + let order = *node.op.get_args().unwrap() + .downcast::<[Axis; 4]>().unwrap(); + let result = if order == [Axis::Axis1, Axis::Axis0, Axis::Axis2, Axis::Axis3] { + af::transpose(expr_map.get(&node.ancestors[0]).unwrap(), false) + } else { + let dims = af::Dim4::new(&[order[0] as u64, + order[1] as u64, + order[2] as u64, + order[3] as u64]); + af::reorder(expr_map.get(&node.ancestors[0]).unwrap(), dims) + }; + expr_map.insert(node.id, result); + }, + "Tanh" => { + let result = af::tanh(expr_map.get(&node.ancestors[0]).unwrap()); + expr_map.insert(node.id, result); + }, + "Neg" => { + let result = expr_map.get(&node.ancestors[0]).unwrap().clone().neg(); + expr_map.insert(node.id, result); + }, + "Sum" => { + let axis = *node.op.get_args().unwrap() + .downcast::<[bool; 4]>().unwrap(); + let mut result = None; + { + let initial = expr_map.get(&node.ancestors[0]).unwrap(); + for i in 0..4 { + if axis[i] { + if result.is_none() { + result = Some(af::sum(initial, i as i32)); + } else { + result = Some(af::sum(&result.unwrap(), i as i32)); + } + } + } + } + expr_map.insert(node.id, result.unwrap()); + }, + "TensorShape" => { + let axis = *node.op.get_args().unwrap() + .downcast::().unwrap(); + let result = { + let parent = expr_map.get(&node.ancestors[0]).unwrap(); + af::constant(parent.dims()[axis as usize] as f32, af::Dim4::new(&[1, 1, 1, 1])) + }; + expr_map.insert(node.id, result); + }, + "Div" => { + let result = { + let parent = expr_map.get(&node.ancestors[0]).unwrap(); + let one = af::constant(1.0f32, af::Dim4::new(&[1, 1, 1, 1])); + af::div(&one, parent, true) + }; + expr_map.insert(node.id, result); + }, + "Broadcast" => { + let result = expr_map.get(&node.ancestors[0]).unwrap().clone(); + expr_map.insert(node.id, result); + }, + "Update" => { + let name = self.gf.graph.nodes[node.ancestors[0]].name.clone(); +// { +// let x = self.parameters.borrow(); +// let p_old = x.get(&name).unwrap(); +// let p_new = expr_map.get(&node.ancestors[1]).unwrap(); +// println!("[{}]{:?} vs [{}]{:?}", +// node.ancestors[0], p_old.dims(), +// node.ancestors[1], p_new.dims()); +// if p_old.dims()[1] == 1 { +// af_print!("Value: ", p_old); +// af_print!("Value: ", p_new); +// } +// } + let upd = expr_map.get(&node.ancestors[1]).unwrap().clone(); + upd.eval(); + self.parameters.borrow_mut().insert(name, upd); + }, + name => { + panic!("Operator {} not implemented.", name) + } + } +// println!("{} - {:?}", id, expr_map.get(&id).map(|x| x.dims())); + } +} diff --git a/src/backend/af/mod.rs b/src/backend/af/mod.rs new file mode 100644 index 0000000..e65a3d6 --- /dev/null +++ b/src/backend/af/mod.rs @@ -0,0 +1,6 @@ +pub mod function; +pub mod backend; + +pub use backend::*; +pub use self::backend::*; +pub use self::function::*; \ No newline at end of file diff --git a/src/backend/common.rs b/src/backend/common.rs index f487924..7846955 100644 --- a/src/backend/common.rs +++ b/src/backend/common.rs @@ -5,31 +5,49 @@ use errors::*; use std::io; use std::collections::HashMap; +#[derive(Debug, Clone, Copy)] +pub struct BackendPrecisions { + pub integer_precision: Precision, + pub float_precision: Precision, + pub complex_precision: Precision +} + +impl Default for BackendPrecisions { + fn default() -> Self { + BackendPrecisions { + integer_precision: Precision::P32, + float_precision: Precision::P32, + complex_precision: Precision::P32, + } + } +} + pub trait CompiledFunction { fn initialized(&self) -> bool; - fn eval(&mut self, inputs: &[TI]) -> Result>; + fn eval(&mut self, inputs: &[&TI]) -> Result>; fn free_memory(&mut self); } #[derive(Debug, Clone, Default)] -pub struct MemoryMap { +pub struct AbstractMemoryMap { + // Maps node id to (offset in memory, size in memory) pub abstract_map: HashMap, - pub current_map: HashMap, - pub current_size: u64 + // (number of booleans, number of integers, number of floats, number of complex) + pub abstract_size: (SymInt, SymInt, SymInt, SymInt), } pub trait Backend: Default { fn info(&self, f: &mut io::Write) -> io::Result<()>; fn general_info(&self, f: &mut io::Write) -> io::Result<()>; - fn print_info(&self) -> io::Result<()> { self.info(&mut io::stdout()) } - fn print_general_info(&self) -> io::Result<()> { self.general_info(&mut io::stdout()) } + fn get_precisions(&self) -> &BackendPrecisions; + fn set_precisions(&mut self, precisions: BackendPrecisions); fn make_function(&self, graph_function: GraphFunction) -> F; } @@ -63,4 +81,43 @@ pub fn verify_shapes(new_shapes: &[[usize; 4]], } else { Ok(None) } +} + +pub fn build_memory_map(gf: &GraphFunction) -> AbstractMemoryMap { + let mut map = HashMap::new(); + let mut offset: SymInt = 0.into(); + let mut b_size: SymInt = 0.into(); + let mut i_size: SymInt = 0.into(); + let mut f_size: SymInt = 0.into(); + let mut c_size: SymInt = 0.into(); + for &i in &gf.graph.order { + let ref node = gf.graph.nodes[i]; + let op_meta = node.op.get_meta(); + match op_meta.name { + "Scalar" | "SymIntInput" | "TensorShape" | "Broadcast" | "Parameter" => {}, + _ => { + let n = gf.graph.nodes[i].shape.elements(); + map.insert(i, (offset.clone(), n.clone())); + match gf.graph.nodes[i].data_type { + FundamentalType::Boolean => { + b_size += &n; + }, + FundamentalType::SignedInt | FundamentalType::UnsignedInt => { + i_size += &n; + }, + FundamentalType::Float => { + f_size += &n; + }, + FundamentalType::Complex => { + c_size += &n; + } + } + offset += &n; + } + } + } + AbstractMemoryMap { + abstract_map: map, + abstract_size: (b_size, i_size, f_size, c_size), + } } \ No newline at end of file diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 672f2fa..64ba8a0 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,4 +1,5 @@ pub mod common; -pub mod opencl; +//pub mod opencl; +pub mod af; pub use self::common::*; \ No newline at end of file diff --git a/src/backend/opencl/backend.rs b/src/backend/opencl/backend.rs index b2f8aff..e822e5e 100644 --- a/src/backend/opencl/backend.rs +++ b/src/backend/opencl/backend.rs @@ -1,11 +1,14 @@ +use primitives::*; use graph::*; use backend::*; use backend::opencl::function::*; -use ocl::{Platform, Device, Context, Queue}; +use ocl::{Platform, Device, Context, Queue, Buffer}; use ocl::core::DeviceInfo; +use ocl::flags::{MemFlags, CommandQueueProperties}; use std::io; use std::collections::HashMap; +use tera::Tera; /// For now this will support only single device @@ -14,7 +17,7 @@ pub struct OpenCLBackend { pub platform: Platform, pub device: Device, pub context: Context, - pub queue: Queue, + pub precisions: BackendPrecisions, } @@ -29,12 +32,11 @@ impl Default for OpenCLBackend { .devices(device) .build() .unwrap(); - let queue = Queue::new(&context, device).unwrap(); OpenCLBackend { platform: platform, device: device, context: context, - queue: queue, + precisions: (Precision::P32, Precision::P32, Precision::P32) } } } @@ -44,16 +46,37 @@ impl Backend for OpenCLBackend { -> OpenCLFunction { let sym_input_shapes = gf.inputs.iter() .map(|&id| gf.graph.nodes[id].shape.clone()).collect(); + let flags = Some(MemFlags::alloc_host_ptr() | MemFlags::read_write()); + let mut kernel_map = HashMap::new(); + let mut tera = compile_templates!("templates/kernels"); + for &i in &gf.graph.order { + let mut context = ::tera::Context::new(); + context.add("b_type", type_to_string(gf.graph.nodes[i].data_type, self.precisions)); + context.add("c_type", "size_t"); + kernel_map.insert(i, tera.render("store.tera", context)); + } OpenCLFunction { - memory_map: MemoryMap::default(), - gf: gf, initialized: false, + precisions: self.precisions, + gf: gf, + memory_map: build_memory_map(&gf), + current_size: 0, sym_input_shapes: sym_input_shapes, last_shapes: Vec::new(), last_deduced: HashMap::new(), + buffer: Buffer::::new(self.queue.clone(), flags, 1 , None).unwrap(), + buffer_map: HashMap::new(), + kernel_map: kernel_map, + queue: self.queue.clone() } } + fn get_precisions(&self) -> &BackendPrecisions { + self.precisions + } + fn set_precisions(&mut self, precisions: BackendPrecisions){ + self.precisions = precisions; + } fn info(&self, f:&mut io::Write) -> io::Result<()> { writeln!(f, "OpenCL Backend Information:")?; // Todo: when String.repeat() becomes stable exchange @@ -91,7 +114,6 @@ impl Backend for OpenCLBackend { writeln!(f, "==================================================")?; Ok(()) } - fn general_info(&self, f: &mut io::Write) -> io::Result<()> { writeln!(f, "OpenCL Backend General Information:")?; writeln!(f, "==================================================")?; @@ -137,64 +159,32 @@ impl Backend for OpenCLBackend { } } -//impl OpenCLBackend { -// pub fn process_graph(&mut self, graph: &Graph) { -// let mut kernel_map = HashMap::new(); -// let mut kernels: Vec = Vec::new(); -// for &id in &graph.order { -// let ref node = graph.nodes[id]; -// let meta = node.op.get_meta(); -// match meta.name { -// "Input" | "Parameter" | "Scalar" => {}, -// "Add" => { -// match node.ancestors.len() { -// 2 => { -// let name = "add_2_float_32"; -// if kernel_map.get(name).is_none() { -// let kernel = format!( -// "__kernel void multiply(__global float* out, -// __global float* in1, -// __global float* in2){{ -// auto id = get_global_id(0); -// out[id] = in1[id] + in2[id]; -// }}"); -// kernel_map.insert(name, kernel); -// } -// }, -// _ => {} -// } -// } -// _ => {} -// } -// } -// } -// -// -// pub fn make_program(&mut self, source: &str) { -// self.program = Some(Program::builder() -// .src(source) -// .devices(self.device) -// .build(&self.context).unwrap()); -// } -// -// pub fn execute_kernel(&self, kernel_name: &str) { -// let dims = &[64]; -// let buffer = Buffer::::new(self.queue.clone(), None, dims, None).unwrap(); -// let kernel = Kernel::new(kernel_name, self.program.as_ref().unwrap(), &self.queue).unwrap() -// .gws(&[10]) -// .arg_buf(&buffer) -// .arg_scl(10.0f32); -// -// let mut event_list = EventList::new(); -// -// let mut result = vec![1.0f32; dims[0]]; -// let mut event = Event::empty(); -// buffer.cmd().write(&result).enq().unwrap(); -// kernel.cmd().enq().unwrap(); -// buffer.cmd().read(&mut result).enew(&mut event).enq().unwrap(); -// event_list.wait().unwrap(); -// println!("{:?}", result); -// } -// -// -//} \ No newline at end of file +pub fn type_to_string(_type: FundamentalType, precisions: &BackendPrecisions) -> String { + match _type { + FundamentalType::Boolean => "bool".into(), + FundamentalType::UnsignedInt => match precisions.integer_precision { + Precision::P8 => unimplemented!(), + Precision::P16 => "uint_16".into(), + Precision::P32 => "uint_32".into(), + Precision::P64 => "uint_64".into(), + }, + FundamentalType::SignedInt => match precisions.integer_precision { + Precision::P8 => unimplemented!(), + Precision::P16 => "int_16".into(), + Precision::P32 => "int_32".into(), + Precision::P64 => "int_64".into(), + }, + FundamentalType::Float => match precisions.float_precision { + Precision::P8 => unimplemented!(), + Precision::P16 => "float_16".into(), + Precision::P32 => "float_32".into(), + Precision::P64 => "float_64".into(), + }, + FundamentalType::UnsignedInt => match precisions.complex_precision { + Precision::P8 => unimplemented!(), + Precision::P16 => unimplemented!(), + Precision::P32 => unimplemented!(), + Precision::P64 => unimplemented!(), + } + } +} diff --git a/src/backend/opencl/function.rs b/src/backend/opencl/function.rs index 882cdc0..c9bb5f0 100644 --- a/src/backend/opencl/function.rs +++ b/src/backend/opencl/function.rs @@ -2,9 +2,13 @@ use primitives::*; use graph::*; use backend::*; use errors::*; +use ocl::{Buffer, Queue}; +use ocl::flags::MemFlags; +use ocl::core::{create_sub_buffer, BufferRegion, Mem}; use std::collections::HashMap; + #[derive(Debug, Clone)] pub struct OpenCLContainer { pub mem: Vec, @@ -13,16 +17,78 @@ pub struct OpenCLContainer { #[derive(Debug, Clone)] pub struct OpenCLFunction { - pub memory_map: MemoryMap, - pub gf: GraphFunction, pub initialized: bool, - pub last_shapes: Vec<[usize;4]>, + pub precisions: BackendPrecisions, + pub gf: GraphFunction, + pub memory_map: AbstractMemoryMap, + pub current_size: usize, pub sym_input_shapes: Vec, - pub last_deduced: HashMap + pub last_shapes: Vec<[usize; 4]>, + pub last_deduced: HashMap, + pub buffer: Buffer, + pub buffer_map: HashMap, + pub kernel_map: HashMap, + pub queue: Queue } impl OpenCLFunction { - fn allocate_buffer(&mut self) { + fn allocate(&mut self) { + // Calculate memory for each type + let size_b = self.memory_map.abstract_size.0.eval(&self.last_deduced).unwrap(); + let size_i = self.memory_map.abstract_size.1 + .eval(&self.last_deduced).unwrap() * + self.precisions.integer_precision as i64; + let size_f = self.memory_map.abstract_size.2 + .eval(&self.last_deduced).unwrap() * + self.precisions.float_precision as i64; + let size_c = self.memory_map.abstract_size.3 + .eval(&self.last_deduced).unwrap() * 2 * + self.precisions.complex_precision as i64; + // Full size of allocation + self.current_size = (size_b + size_i + size_f + size_c) as usize; + // Allocate full memory + let flags = Some(MemFlags::alloc_host_ptr() | MemFlags::read_write()); + self.buffer = Buffer::::new(self.queue.clone(), flags, self.current_size, None).unwrap(); + if new_size > 1024 * 1024 { + println!("Allocating {:.2} MB of memory.", new_size as f64 / (1024.0 * 1024.0)); + } else if new_size > 1024 { + println!("Allocating {:.2} KB of memory.", new_size as f64 / 1024.0); + } else { + println!("Allocating {:.2} B of memory.", new_size); + } + let mut map = HashMap::new(); + // Create buffers for each node + for (&id, &(offset, size)) in self.memory_map.abstract_map.iter() { + let offset = offset.eval(&self.last_deduced).unwrap(); + let size = size.eval(&self.last_deduced).unwrap(); + let sub = match self.gf.graph.nodes[id].data_type { + FundamentalType::Boolean => { + create_sub_buffer::(&self.buffer, + MemFlags::read_write(), + &BufferRegion::new(offset, size)).unwrap(); + }, + FundamentalType::UnsignedInt => match self.precisions.integer_precision { + Precision::P8 => unimplemented!(), + Precision::P16 => { + let offset = offset / 2; + create_sub_buffer::(&self.buffer, + MemFlags::read_write(), + &BufferRegion::new(offset, size)).unwrap(); + } + Precision::P32 => unimplemented!(), + Precision::P64 => unimplemented!(), + create_sub_buffer::(&self.buffer, MemFlags::read_write(), &BufferRegion::new(offset, size)).unwrap(); + }, + FundamentalType::SignedInt => { + create_sub_buffer::(&self.buffer, MemFlags::read_write(), &BufferRegion::new(offset, size)).unwrap(); + }, + _ => { + create_sub_buffer::(&self.buffer, MemFlags::read_write(), &BufferRegion::new(offset, size)).unwrap(); + }, + }; + map.insert(id, sub); + } + // Create sub buffers for each node // Todo // Based on self.last_deduced should evaluate all of the memory needed // and allocate ocl::Buffer accordingly @@ -36,14 +102,6 @@ impl OpenCLFunction { } impl CompiledFunction for OpenCLFunction { - fn initialized(&self) -> bool { - self.initialized - } - - fn free_memory(&mut self) { - // Free all of the OpenCL Buffers - } - fn eval(&mut self, inputs: &[OpenCLContainer]) -> Result> { // Check correct number of inputs are provided if inputs.len() != self.gf.inputs.len() { @@ -56,8 +114,7 @@ impl CompiledFunction for OpenCLFunction { Some(deduced) => { self.last_shapes = input_shapes; self.last_deduced = deduced; - // Allocate memory as needed for the new exact shapes - self.allocate_buffer(); + self.allocate(); }, None => {} } @@ -66,4 +123,11 @@ impl CompiledFunction for OpenCLFunction { // Todo copy outputs ocl::Buffer to OpenCLContainers Ok(Vec::new()) } + + fn initialized(&self) -> bool { + self.initialized + } + fn free_memory(&mut self) { + self.buffer = Buffer::::new(self.queue.clone(), None, 1, None).unwrap(); + } } diff --git a/src/graph.rs b/src/graph.rs index 91c0ed4..3764c9b 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -19,7 +19,7 @@ pub struct ExprData{ pub id: usize, pub name: String, pub ancestors: Vec, - pub children: Vec, + pub children: HashSet, pub op: Box, pub data_type: FundamentalType, pub shape: Shape, @@ -64,7 +64,7 @@ pub struct Graph { pub grad_level: usize, pub scope: Vec, pub op_map: HashMap>, - pub updates: HashMap, + // pub updates: HashMap, pub log: Logger, } @@ -76,16 +76,19 @@ impl Default for Graph { impl Graph { pub fn new(log: Logger) -> Self { - Graph { + let mut graph = Graph { nodes: Vec::new(), order: Vec::new(), props: GraphProperties::default(), grad_level: 0, scope: Vec::new(), op_map: HashMap::new(), - updates: HashMap::new(), + // updates: HashMap::new(), log: log - } + }; + // Todo insert all ops + graph.op_map.insert("Update".into(), Vec::new()); + graph } pub fn scope_str(&self) -> String { @@ -120,7 +123,7 @@ impl Graph { data.scope = self.scope.clone(); // println!("Adding node {:?}", data); for &a in &data.ancestors { - self.nodes[a].children.push(data.id) + self.nodes[a].children.insert(data.id); } // Insert into op_map if !self.op_map.contains_key(data.op.get_meta().name) { @@ -145,14 +148,14 @@ impl Graph { // Check if parameter already exists if let Some(v) = self.op_map.get("Parameter") { let (_, _, name) = *data.op.get_args().unwrap() - .downcast::<(FundamentalType, Shape, String)>().unwrap(); + .downcast::<(FundamentalType, Shape, Vec)>().unwrap(); for &id in v { let (_, _, v_name) = *self.nodes[id].op.get_args().unwrap() - .downcast::<(FundamentalType, Shape, String)>().unwrap(); + .downcast::<(FundamentalType, Shape, Vec)>().unwrap(); if name == v_name { return Err(ErrorKind::Msg( format!("The parameter '{}' already exists \ - in the graph.", name)).into()) + in the graph.", name.join("::"))).into()) } } } @@ -286,8 +289,10 @@ impl Graph { } pub fn parameter(&mut self, data_type: FundamentalType, shape: Shape, name: String) -> Result { + let mut param_name = self.scope.clone(); + param_name.push(name); let op = Box::new(Parameter{ - param_name: self.name_in_scope(&name), + param_name: param_name, data_type: data_type, shape: shape }); @@ -306,9 +311,22 @@ impl Graph { let node: &ExprData = self.nodes.get(id).unwrap(); let op = node.op.clone(); graph.scope = self.nodes[id].scope.clone(); - let new_id = match op.get_meta().name { + match op.get_meta().name { "Input" | "Parameter" | "Scalar" => { - graph.add_node(op.apply_null())? + let new_id = graph.add_node(op.apply_null())?; + provided.insert(id, new_id); + }, + "Update" => if ! discard_updates { + let new_id = { + let arg = &node.ancestors[0]; + let arg = provided.get(arg).ok_or(ErrorKind::Msg( + format!("The argument {} needed for updates is not provided.", arg)))?; + let upd = &node.ancestors[1]; + let upd = provided.get(upd).ok_or(ErrorKind::Msg( + format!("The argument {} needed for updates is not provided.", upd)))?; + ids::update(graph, *arg, *upd)? + }; + provided.insert(id, new_id); }, _ => { let mut ancestors: Vec = Vec::with_capacity(node.ancestors.len()); @@ -318,22 +336,13 @@ impl Graph { ancestors.push(v); } let data = op.apply(graph, ancestors)?; - graph.add_node(data)? + let new_id = graph.add_node(data)?; + provided.insert(id, new_id); } }; - provided.insert(id, new_id); } } graph.scope = init_scope; - if ! discard_updates { - for (a, u) in self.updates.iter() { - let ap = provided.get(a).ok_or(ErrorKind::Msg( - format!("The argument {} needed for updates is not provided.", a)))?; - let up = provided.get(a).ok_or(ErrorKind::Msg( - format!("The argument {} needed for updates is not provided.", u)))?; - ids::update(graph, *ap, *up)?; - } - } Ok(provided) } } @@ -343,8 +352,6 @@ pub struct GraphWrapper { pub graph: Rc> } -pub type MutGraph<'a> = RefMut<'a, Graph>; - impl GraphWrapper { pub fn new(log: Logger) -> Self { GraphWrapper { @@ -418,6 +425,7 @@ pub struct GraphFunction { pub graph: Graph, pub inputs: Vec, pub outputs: Vec, + pub parameters: HashMap, pub unique_symints: HashSet, } @@ -440,9 +448,9 @@ impl GraphFunction { leafs.clone_from_slice(outputs); // Add updates from the graph if ! discard_updates { - for (&var, &upd) in &graph.updates { - leafs.push(var); - leafs.push(upd); + for &u in graph.op_map.get("Update").unwrap() { + leafs.push(graph.nodes[u].ancestors[0]); + leafs.push(graph.nodes[u].ancestors[1]); } } // Add extra updates @@ -477,12 +485,18 @@ impl GraphFunction { node.shape.2.unique_identifiers(&mut unique); node.shape.3.unique_identifiers(&mut unique); } + // Find all of the parameters + let params = sub_graph.op_map.get("Parameter") + .map(|v| v.iter() + .map(|&id| (sub_graph.nodes[id].name.clone(), id)).collect()) + .unwrap_or(HashMap::new()); // Return the function created Ok(GraphFunction{ name: name.unwrap_or("main".into()), graph: sub_graph, inputs: inputs.iter().map(|x| *mapping.get(x).unwrap()).collect(), outputs: outputs.iter().map(|x| *mapping.get(x).unwrap()).collect(), + parameters: params, unique_symints: unique, }) } diff --git a/src/lib.rs b/src/lib.rs index 80f9fe5..27ed14f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,8 @@ extern crate slog_term; #[macro_use] extern crate tera; extern crate ocl; +#[macro_use(af_print)] +extern crate arrayfire; pub mod primitives; pub mod errors; diff --git a/src/main.rs b/src/main.rs index a3255e3..853a508 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,14 @@ #[macro_use] extern crate gir; -extern crate slog_term; - use gir::api::*; -use gir::backend::opencl::*; -use gir::backend::*; use std::fs::File; +//use gir::backend::opencl::*; +use gir::backend::af::*; + fn main() { let f = make_graph().unwrap(); - compile_and_run(f); + compile_and_run_af(f); } #[allow(unused_variables, unused_mut)] @@ -19,7 +18,7 @@ fn make_graph() -> gir::errors::Result { // Learning rate let alpha = &f_param!(g, (), "alpha")?; // Dummy - let beta = &f_var!(g, ()); +// let beta = &f_var!(g, ()); // Input let x = &f_var!(g, (784, "n"), "input"); // Targets @@ -38,11 +37,11 @@ fn make_graph() -> gir::errors::Result { // Generate SGD updates g.get_mut().scope.push("updates".into()); let updates: Vec<(gir::Expr, gir::Expr)> = params.iter().zip(grads.iter()) - .map(|(ref p, ref g)| ((**p).clone(), **p - alpha * g)).collect(); + .map(|(&& ref p, ref g)| (p.clone(), p - alpha * g)).collect(); g.get_mut().scope.clear(); // Compile function let f = gir::GraphFunction::new_from_expr(&[x.clone(), y.clone()], &[error], - false, &updates[..], Some("test_func".into()))?; + false, &updates[..], Some("test_func".into()))?; println!("{} - {}", g.get().nodes.len(), f.graph.nodes.len()); let mut file = File::create("target/html/foo.dot").unwrap(); gir::export::dot::to_dot(&mut file, &f.graph).unwrap(); @@ -57,32 +56,32 @@ fn make_graph() -> gir::errors::Result { // } //"#; +#[macro_use(af_print)] +extern crate arrayfire as af; +use af::print_gen; + #[allow(unused_variables, unused_mut)] -pub fn compile_and_run(func: gir::GraphFunction) { - let backend = OpenCLBackend::default(); +pub fn compile_and_run_af(func: gir::GraphFunction) { + // Initialize backend + let mut backend = AfBackend::default(); backend.print_general_info().unwrap(); - backend.print_info().unwrap(); + // Initialize parameters + let alpha = af::constant::(0.001, af::Dim4::new(&[1, 1, 1, 1])); + backend.set_param_value("alpha", alpha); + let w1 = af::randn::(af::Dim4::new(&[1, 784, 1, 1])) / 100.0f32; + backend.set_param_value("w1", w1); + let b1 = af::randn::(af::Dim4::new(&[1, 1, 1, 1])) / 100.0f32; + backend.set_param_value("b1", b1); + // Make inputs + let input = af::randu::(af::Dim4::new(&[784, 20, 1, 1])); + let target = af::randu::(af::Dim4::new(&[20, 1, 1, 1])); + let ins = &vec![&input, &target]; + // Compile function let mut f = backend.make_function(func); - let x = OpenCLContainer { - mem: vec![1.0; 784*10], - dims: [784, 10, 1, 1] - }; - let y = OpenCLContainer { - mem: vec![0.5; 10], - dims: [10, 1, 1, 1] - }; - let y_wrong = OpenCLContainer { - mem: vec![0.5; 12], - dims: [12, 1, 1, 1] - }; - // Not correct number of inputs - println!("{:?}", f.eval(&vec![])); - // Incorrect size, e.g. x.dim1 = 10, y_wrong.dim0 = 12 - println!("{:?}", f.eval(&vec![x.clone(), y_wrong])); - // Correct - println!("{:?}", f.eval(&vec![x, y])); - for (ref sym, ref ls) in f.sym_input_shapes.iter().zip(f.last_shapes.iter()) { - println!("{} - {:?}", sym, ls); + // Run 100 iterations + let mut result = [0.0f32]; + for i in 0..100 { + f.eval(ins).unwrap().pop().unwrap().host(&mut result); + println!("Iteration {}: {:.5e}", i, result[0]); } - println!("{:?}", f.last_deduced) -} +} \ No newline at end of file diff --git a/src/ops/input.rs b/src/ops/input.rs index 4ed551c..b172cb5 100644 --- a/src/ops/input.rs +++ b/src/ops/input.rs @@ -2,6 +2,7 @@ use ops::interface::*; use primitives::*; use graph::*; use errors::*; +use std::collections::HashSet; use symbolic_polynomials::variable; use std::any::Any; @@ -23,7 +24,7 @@ impl Operator for Input { id: 0, name: "".into(), ancestors: Vec::new(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: self.data_type, shape: self.shape.clone(), @@ -83,7 +84,7 @@ impl Operator for Input { #[derive(Debug, Clone)] pub struct Parameter { - pub param_name: String, + pub param_name: Vec, pub data_type: FundamentalType, pub shape: Shape } @@ -98,9 +99,9 @@ impl Operator for Parameter { fn apply_null(&self) -> ExprData { ExprData{ id: 0, - name: self.param_name.clone(), + name: format!("{}", self.param_name.join("::")), ancestors: Vec::new(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: self.data_type, shape: self.shape.clone(), @@ -176,7 +177,7 @@ impl Operator for Scalar { id: 0, name: "Scalar".into(), ancestors: Vec::new(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: self.data_type, shape: Shape::scalar_shape(), @@ -239,7 +240,7 @@ impl Operator for SymIntInput { id: 0, name: "SymInt".into(), ancestors: Vec::new(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: FundamentalType::UnsignedInt, shape: Shape::scalar_shape(), @@ -285,6 +286,55 @@ impl Operator for SymIntInput { } } +#[derive(Debug, Clone)] +pub struct Cleared {} + +impl Operator for Cleared { + #[allow(unused_variables, unused_mut)] + fn reverse_diff(&self, g: &mut Graph, x: usize, dx: usize, flow_tree: &Vec) + -> Result> { + unimplemented!() + } + fn apply_null(&self) -> ExprData { + ExprData{ + id: 0, + name: "Cleared".into(), + ancestors: Vec::new(), + children: HashSet::new(), + op: self.clone_box(), + data_type: FundamentalType::Boolean, + shape: Shape::scalar_shape(), + is_input_dependent: false, + is_differentiable: false, + matrix_positivity: MatrixPositivity::PositiveDefinite, + matrix_symmetry: MatrixSymmetry::Symmetric, + matrix_fill: MatrixFill::Diagonal, + grad_level: 0, + scope: Vec::new(), + sym_int: None + } + } + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + fn get_meta(&self) -> &OperatorMetaData { + static CLEARED: OperatorMetaData = OperatorMetaData{ + name: "Cleared", + arity: Arity::Nullary, + num_outputs: 1, + differential_parents: 0, + ordered_parents: false, + elementwise: false, + type_preserving: false, + reduction: false, + differentiable: false, + scalar_output: true, + shape_operator: false, + fixed_output_type: Some(FundamentalType::Boolean), + }; + &CLEARED + } +} diff --git a/src/ops/interface.rs b/src/ops/interface.rs index 28080f0..b03bef9 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -3,6 +3,7 @@ use graph::*; use errors::*; use api::ids; use std::any::Any; +use std::collections::HashSet; //use std::borrow::Borrow; //use std::cell::RefCell; @@ -55,7 +56,7 @@ pub trait Operator: ::std::fmt::Debug { id: 0, name: "".into(), ancestors: args.clone(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: self.get_data_type(g, &args), shape: self.get_shape(g, &args), @@ -254,8 +255,7 @@ pub mod default { } // Make sure all arguments are up to that shape, if not broadcast them accordingly for a in args.iter_mut() { - if shape != graph.get_node(*a).unwrap().shape && - graph.get_node(*a).unwrap().shape != Shape::scalar_shape() { + if shape != graph.get_node(*a).unwrap().shape { let br: Vec> = Axis::iter().zip(shape_i.iter()) .map(|(&axis, &arg_id)| { if shape.get(axis) != graph.get_node(*a).unwrap().shape.get(axis) { @@ -281,6 +281,17 @@ pub mod default { *a = ids::broadcast(graph, *a, [br[0], br[1], br[2], br[3]])?; } } + // Put any scalars at the back +// let mut scalars = Vec::new(); +// let mut i = 0; +// while i < args.len() { +// if graph.nodes[args[i]].shape.order() == 0 { +// scalars.push(args.remove(i)); +// } else { +// i += 1; +// } +// } +// args.append(&mut scalars); Ok(args) } diff --git a/src/ops/linalg.rs b/src/ops/linalg.rs index 5325fb8..5877018 100644 --- a/src/ops/linalg.rs +++ b/src/ops/linalg.rs @@ -6,9 +6,9 @@ use api::*; #[derive(Debug, Clone)] -pub struct MatrixMul {} +pub struct MatMul {} -impl Operator for MatrixMul { +impl Operator for MatMul { #[allow(unused_variables, unused_mut)] fn reverse_diff(&self, g: &mut Graph, x: usize, dx: usize, flow_tree: &Vec) -> Result> { @@ -16,8 +16,8 @@ impl Operator for MatrixMul { if anc.len() == 2 { let mut res = Vec::new(); if flow_tree[anc[0]] { - let dx_transpose = ids::reorder(g, dx, None)?; - res.push((anc[0], ids::mat_mul(g, anc[1], dx_transpose)?)); + let anc_transpose = ids::reorder(g, anc[1], None)?; + res.push((anc[0], ids::mat_mul(g, dx, anc_transpose)?)); } if flow_tree[anc[1]] { let anc_transpose = ids::reorder(g, anc[0], None)?; @@ -60,7 +60,7 @@ impl Operator for MatrixMul { fn get_meta(&self) -> &OperatorMetaData { static MATRIX_MUL: OperatorMetaData = OperatorMetaData{ - name: "MatrixMul", + name: "MatMul", arity: Arity::Nary, num_outputs: 1, differential_parents: ::std::usize::MAX, diff --git a/src/ops/shape.rs b/src/ops/shape.rs index 8b8e892..41ee666 100644 --- a/src/ops/shape.rs +++ b/src/ops/shape.rs @@ -3,6 +3,7 @@ use primitives::*; use graph::*; use errors::*; use std::any::Any; +use std::collections::HashSet; #[derive(Debug, Clone)] @@ -50,7 +51,7 @@ impl Operator for TensorShape { id: 0, name: "".into(), ancestors: args.clone(), - children: Vec::new(), + children: HashSet::new(), op: self.clone_box(), data_type: self.get_data_type(g, &args), shape: self.get_shape(g, &args), diff --git a/src/ops/special.rs b/src/ops/special.rs index 8f2f167..5274373 100644 --- a/src/ops/special.rs +++ b/src/ops/special.rs @@ -5,61 +5,61 @@ use errors::*; use api::*; use std::any::Any; -//#[derive(Debug, Clone)] -//pub struct Update {} -// -//impl Operator for Update { -// #[allow(unused_variables, unused_mut)] -// fn reverse_diff(&self, g: &mut Graph, x: usize, dx: usize, flow_tree: &Vec) -// -> Result> { -// unimplemented!() -// } -// -// fn verify_args(&self, g: &mut Graph, args: Vec) -> Result> { -// let meta = self.get_meta(); -// let args = default::verify_args(meta, g, args)?; -// // Verify first argument is a Parameter -// if g.get_node(args[0])?.op.get_meta().name != "Parameter" { -// return Err(ErrorKind::InvalidArguments( -// String::new() + meta.name, args, -// "First argument must be a parameter.".into()).into()) -// } -// // Verify that the first argument does not already have an Update -// match g.op_map.get("Update").unwrap_or(&Vec::new()).iter().position(|&x| x == args[0]) { -// Some(_) => { -// let param_name = g.get_node(args[0])?.op.get_args().unwrap() -// .downcast::<(String, FundamentalType, Shape)>().unwrap().0; -// Err(ErrorKind::InvalidArguments( -// String::new() + meta.name, args, -// format!("The parameter '{}' already has an update.", param_name)).into()) -// }, -// None => Ok(args) -// } -// -// } -// -// fn clone_box(&self) -> Box { -// Box::new(self.clone()) -// } -// -// fn get_meta(&self) -> &OperatorMetaData { -// static UPDATE: OperatorMetaData = OperatorMetaData{ -// name: "Update", -// arity: Arity::Binary, -// num_outputs: 0, -// differential_parents: 0, -// ordered_parents: true, -// elementwise: true, -// type_preserving: false, -// reduction: false, -// differentiable: false, -// scalar_output: false, -// shape_operator: false, -// fixed_output_type: None, -// }; -// &UPDATE -// } -//} +#[derive(Debug, Clone)] +pub struct Update {} + +impl Operator for Update { + #[allow(unused_variables, unused_mut)] + fn reverse_diff(&self, g: &mut Graph, x: usize, dx: usize, flow_tree: &Vec) + -> Result> { + unimplemented!() + } + + fn verify_args(&self, g: &mut Graph, args: Vec) -> Result> { + let meta = self.get_meta(); + let args = default::verify_args(meta, g, args)?; + // Verify first argument is a Parameter + if g.get_node(args[0])?.op.get_meta().name != "Parameter" { + return Err(ErrorKind::InvalidArguments( + String::new() + meta.name, args, + "First argument must be a parameter.".into()).into()) + } + // Verify that the first argument does not already have an Update + match g.op_map.get("Update").unwrap_or(&Vec::new()).iter().position(|&x| x == args[0]) { + Some(_) => { + let param_name = g.get_node(args[0])?.op.get_args().unwrap() + .downcast::<(String, FundamentalType, Shape)>().unwrap().0; + Err(ErrorKind::InvalidArguments( + String::new() + meta.name, args, + format!("The parameter '{}' already has an update.", param_name)).into()) + }, + None => Ok(args) + } + + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn get_meta(&self) -> &OperatorMetaData { + static UPDATE: OperatorMetaData = OperatorMetaData{ + name: "Update", + arity: Arity::Binary, + num_outputs: 0, + differential_parents: 0, + ordered_parents: true, + elementwise: true, + type_preserving: false, + reduction: false, + differentiable: false, + scalar_output: false, + shape_operator: false, + fixed_output_type: None, + }; + &UPDATE + } +} #[derive(Debug, Clone)] diff --git a/src/primitives.rs b/src/primitives.rs index 1abae5b..114991d 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -25,10 +25,10 @@ impl ::std::fmt::Display for FundamentalType { /// Variable storage precisions #[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Ord)] pub enum Precision { - P8 = 0, - P16 = 1, - P32 = 2, - P64 = 3, + P8 = 1, + P16 = 2, + P32 = 4, + P64 = 8, } impl ::std::fmt::Display for Precision { @@ -42,7 +42,6 @@ impl ::std::fmt::Display for Precision { } } - /// Operator arity (Number of arguments) #[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Ord)] pub enum Arity { @@ -223,6 +222,10 @@ impl Shape { } } + pub fn elements(&self) -> SymInt { + &self.0 * &self.1 * &self.2 * &self.3 + } + pub fn get(&self, axis: Axis) -> &SymInt { match axis { Axis::Axis0 => &self.0, @@ -297,3 +300,38 @@ impl ::std::fmt::Display for Policy { } } +//#[derive(Debug, Clone)] +//pub struct ScopedName { +// pub scoped_name: Vec, +// pub scope_delimiter: String +//} +// +//impl PartialEq for ScopedName { +// fn eq(&self, other: &ScopedName) -> bool { +// if self.scoped_name.len() != other.scoped_name.len() { +// return false +// } +// for (name1, name2) in self.scoped_name.iter().zip(other.scoped_name.iter()) { +// if name1 != name2 { +// return false +// } +// } +// return true +// } +//} +// +//impl Eq for ScopedName {} +// +//impl ::std::hash::Hash for ScopedName { +// fn hash(&self, state: &mut H) { +// self.scoped_name.join(&self.scope_delimiter).hash(state); +// } +//} +// +//impl ::std::fmt::Display for ScopedName { +// fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { +// write!(f, "{}", self.scoped_name.join(&self.scope_delimiter)) +// } +//} + + diff --git a/src/props.rs b/src/props.rs index 6d2f14a..2325a88 100644 --- a/src/props.rs +++ b/src/props.rs @@ -36,6 +36,7 @@ pub struct GraphProperties { impl Default for GraphProperties { fn default() -> Self { + // Todo this should be loaded from the environment GraphProperties { http_proxy: None, scope_delimiter: "::".into(), diff --git a/templates/kernels/store.tera b/templates/kernels/store.tera new file mode 100644 index 0000000..d617e63 --- /dev/null +++ b/templates/kernels/store.tera @@ -0,0 +1,14 @@ +// Typedefs in the form TYPE_PRECISION +typedef short int_16; +typedef int int_32; +typedef long int_64; +typedef ushort uint_16; +typedef uint uint_32; +typedef ulong uint_64; +typedef half float_16; +typedef float float_32; +typedef double float_64; + +__kernel void store(__global {{b_type}}* buffer, {{c_type}} value) { + buffer[get_global_id(0)] = value; +}