Skip to content

Commit

Permalink
First working example with Arrayfire backend - #3 #4
Browse files Browse the repository at this point in the history
  • Loading branch information
botev committed Feb 14, 2017
1 parent 0a69483 commit 3d8b1a3
Show file tree
Hide file tree
Showing 22 changed files with 819 additions and 240 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ slog = {version = "1.2", features = ["max_level_trace"] }
slog-term = "1.3.5"
tera = "0.6.2"
ocl = "0.12.0"
arrayfire = "3.4.1"


[replace]
"ocl-core:0.3.2" = { git = "https://github.com/cogciprocate/ocl-core" }
"ocl:0.12.0" = { git = "https://github.com/cogciprocate/ocl/"}

[lib]
name = "gir"
Expand Down
5 changes: 3 additions & 2 deletions src/api/exprs/special.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ops::interface::default::*;
use super::super::ids;
use std::ops::DerefMut;

pub fn overwrite_update<T: AsRef<Expr>>(arg: T, upd: T) -> Result<()> {
pub fn overwrite_update<T: AsRef<Expr>>(arg: T, upd: T) -> Result<bool> {
let arg = arg.as_ref();
let upd = upd.as_ref();
same_graph_2(arg, upd)?;
Expand All @@ -16,7 +16,8 @@ pub fn update<T: AsRef<Expr>>(arg: T, upd: T) -> Result<()> {
let arg = arg.as_ref();
let upd = upd.as_ref();
same_graph_2(arg, upd)?;
ids::update(arg.wrapper.get_mut().deref_mut(), arg.id, upd.id)
ids::update(arg.wrapper.get_mut().deref_mut(), arg.id, upd.id)?;
Ok(())
}

pub fn cast<T: AsRef<Expr>>(arg: T, data_type: FundamentalType) -> Result<Expr> {
Expand Down
2 changes: 1 addition & 1 deletion src/api/ids/linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ use graph::*;
use errors::*;

pub fn mat_mul(graph: &mut Graph, arg0: usize, arg1: usize) -> Result<usize> {
Ok(graph.apply_op(Box::new(MatrixMul {}), vec![arg0, arg1])?)
Ok(graph.apply_op(Box::new(MatMul {}), vec![arg0, arg1])?)
}
39 changes: 28 additions & 11 deletions src/api/ids/special.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use graph::*;
use errors::*;
use api::ids;

pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> {
graph.updates.remove(&arg);
update(graph, arg, upd)
pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result<bool> {
let existed = remove_update(graph, arg)?;
update(graph, arg, upd)?;
Ok(existed)
// let candidates = graph.get_node(arg)?.children.clone();
// let mut old_update = None;
// for &c in &candidates {
Expand Down Expand Up @@ -36,22 +37,38 @@ pub fn overwrite_update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()>
// }
}

pub fn update(graph: &mut Graph, arg:usize, upd: usize) -> Result<()> {
pub fn update(graph: &mut Graph, arg:usize, upd: usize) -> Result<usize> {
// Verify first argument is a Parameter
if graph.get_node(arg)?.op.get_meta().name != "Parameter" {
return Err(ErrorKind::InvalidArguments(
"Update".into(), vec![arg, upd],
"First argument must be a parameter.".into()).into())
}
// Verify that the first argument does not already have an Update
if let Some(u) = graph.updates.get(&arg) {
let ref param_name = graph.nodes[arg].name;
return Err(ErrorKind::InvalidArguments(
"Update".into(), vec![arg, upd],
format!("The parameter '{}' already has an update - {}.", param_name, u)).into())
for &u in graph.op_map.get("Update").unwrap() {
if graph.nodes[u].ancestors[0] == arg {
let ref param_name = graph.nodes[arg].name;
return Err(ErrorKind::InvalidArguments(
"Update".into(), vec![arg, upd],
format!("The parameter '{}' already has an update - {}.", param_name, u)).into())
}
}
graph.apply_op(Box::new(Update {}), vec![arg, upd])
}

pub fn remove_update(graph: &mut Graph, arg:usize) -> Result<bool> {
let updates = graph.op_map.get("Update").unwrap();
for &u in updates {
if graph.nodes[u].ancestors[0] == arg {
let upd = graph.nodes[u].ancestors[1];
graph.nodes[arg].children.remove(&u);
graph.nodes[upd].children.remove(&u);
let op = Box::new(Cleared{});
graph.nodes[u] = op.apply_null();
return Ok(true)
}
}
graph.updates.insert(arg, upd);
Ok(())
Ok(false)
}

pub fn cast(graph: &mut Graph, arg: usize, data_type: FundamentalType) -> Result<usize> {
Expand Down
98 changes: 98 additions & 0 deletions src/backend/af/backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use primitives::*;
use graph::*;
use backend::*;

use std::collections::HashMap;
use std::rc::Rc;
use std::cell::{Ref, RefCell};
use std::io;
use af::function::AfFunction;
use arrayfire as af;
use arrayfire::Array;


/// For now this will support only single device
#[derive(Clone)]
pub struct AfBackend {
pub platform: ::arrayfire::Backend,
pub device: i32,
pub parameters: Rc<RefCell<HashMap<String, Array>>>,
pub precisions: BackendPrecisions
}


impl Default for AfBackend {
fn default() -> Self {
// Todo similar to GraphProps this should be loaded from system file
AfBackend {
platform: ::arrayfire::Backend::DEFAULT,
device: 0,
parameters: Rc::new(RefCell::new(HashMap::new())),
precisions: BackendPrecisions::default()
}
}
}

impl AfBackend {
pub fn get_param_value(&self, name: &str) -> Ref<Array> {
Ref::map(self.parameters.borrow(), |x| x.get(name).unwrap())
}

pub fn set_param_value(&mut self, name: &str, value:Array) -> Result<(), String> {
if let Some(v) = self.parameters.borrow().get(name) {
if v.dims() != value.dims() {
return Err(format!("The parameter {} has shape {}, \
but {} was passed to set_param_value.", name, v.dims(), value.dims()))
}
}
self.parameters.borrow_mut().insert(name.into(), value);
Ok(())
}
}

impl Backend<AfFunction> for AfBackend {
fn make_function(&self, gf: GraphFunction)
-> AfFunction {
let sym_input_shapes = gf.inputs.iter()
.map(|&id| gf.graph.nodes[id].shape.clone()).collect();
AfFunction {
initialized: false,
precisions: self.precisions,
gf: gf,
parameters: self.parameters.clone(),
sym_input_shapes: sym_input_shapes,
last_shapes: Vec::new(),
last_deduced: HashMap::new(),
expr_map: HashMap::new()
}
}

fn get_precisions(&self) -> &BackendPrecisions {
&self.precisions
}
fn set_precisions(&mut self, precisions: BackendPrecisions){
self.precisions = precisions;
}
fn info(&self, f:&mut io::Write) -> io::Result<()> {
writeln!(f, "Platform: {}", self.platform)?;
writeln!(f, "\tDevices: {}", af::device_count())
}

fn general_info(&self, f: &mut io::Write) -> io::Result<()> {
let backend = af::get_active_backend();
writeln!(f, "Arrayfire Backend General Information:")?;
writeln!(f, "==================================================")?;
for b in af::get_available_backends() {
writeln!(f, "Platform: {}", b)?;
writeln!(f, "\tDevices: {}", af::device_count())?;
af::set_backend(b);
af::info();
}
af::set_backend(backend);
Ok(())
}

fn print_info(&self) -> io::Result<()> {
Ok(::arrayfire::info())
}
}
Loading

0 comments on commit 3d8b1a3

Please sign in to comment.