Skip to content

Commit

Permalink
Rename get to get_mut, add new get to LockedArray, lifetimes to D::Da…
Browse files Browse the repository at this point in the history
…ta, D::Wrap
  • Loading branch information
elftausend committed Oct 23, 2024
1 parent d3bd64e commit 92e571c
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 98 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

# default = ["cpu"]
default = ["cpu"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
Expand Down
6 changes: 3 additions & 3 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mod num;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Buffer<'a, T: Unit = f32, D: Device = CPU<Base>, S: Shape = ()> {
/// the type of pointer
pub(crate) data: D::Data<T, S>,
pub(crate) data: D::Data<'a, T, S>,
/// A reference to the corresponding device. Mainly used for operations without a device parameter.
#[cfg_attr(feature = "serde", serde(skip))]
pub(crate) device: Option<&'a D>,
Expand Down Expand Up @@ -273,7 +273,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
#[inline]
pub fn to_deviceless<'b>(self) -> Buffer<'b, T, D, S>
where
D::Data<T, S>: Default,
D::Data<'b, T, S>: Default,
{
if let Some(device) = self.device {
if self.data.flag() != AllocFlag::None {
Expand All @@ -298,7 +298,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
}

#[inline]
pub fn data(&self) -> &D::Data<T, S> {
pub fn data(&self) -> &D::Data<'a, T, S> {
&self.data
}

Expand Down
2 changes: 1 addition & 1 deletion src/buffer/impl_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ where
T: Unit + 'static,
S: Shape,
D: WriteBuf<T, S> + Device + Retriever<T, S>,
<CPU<Mods> as Device>::Data<T, S>: core::ops::Deref<Target = [T]>,
<CPU<Mods> as Device>::Data<'b, T, S>: core::ops::Deref<Target = [T]>,
{
fn from((device, buf): (&'a D, Buffer<'b, T, CPU<Mods>, S>)) -> Self {
let mut out = device.retrieve(buf.len(), &buf).unwrap();
Expand Down
30 changes: 15 additions & 15 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<T> ShallowCopy for Num<T> {
}

impl Device for () {
type Data<T: Unit, S: crate::Shape> = Self::Base<T, S>;
type Data<'a, T: Unit, S: crate::Shape> = Self::Base<T, S>;
type Base<T: Unit, S: crate::Shape> = Num<T>;

type Error = Infallible;
Expand All @@ -61,28 +61,28 @@ impl Device for () {
}

#[inline(always)]
fn base_to_data<T: Unit, S: crate::Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
fn base_to_data<'a, T: Unit, S: crate::Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
base
}

#[inline(always)]
fn wrap_to_data<T: Unit, S: crate::Shape>(
fn wrap_to_data<'a, T: Unit, S: crate::Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

#[inline(always)]
fn data_as_wrap<T: Unit, S: crate::Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: crate::Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

fn data_as_wrap_mut<T: Unit, S: crate::Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: crate::Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}
}
Expand All @@ -107,20 +107,20 @@ impl<T: Unit + Default> Alloc<T> for () {
}

impl WrappedData for () {
type Wrap<T: Unit, Base: crate::HasId + crate::PtrType> = Base;
type Wrap<'a, T: Unit, Base: crate::HasId + crate::PtrType> = Base;

#[inline]
fn wrap_in_base<T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
fn wrap_in_base<'a, T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
base
}

#[inline]
fn wrapped_as_base<T: Unit, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: HasId + PtrType>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: HasId + PtrType>(wrap: &'b mut Self::Wrap<'a, T, Base>) -> &'b mut Base {
wrap
}
}
Expand Down
37 changes: 26 additions & 11 deletions src/cache/locking/locked_array.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::cell::{RefCell, RefMut};
use core::cell::{Ref, RefCell, RefMut};

use crate::cow_mut::CowMutCell;

Expand Down Expand Up @@ -32,6 +32,21 @@ impl<T, const N: usize> LockedArray<T, N> {
}

pub fn get<'a>(&'a self, id: usize) -> State<Guard<'a, T>> {
match self.data[id].try_borrow() {
Ok(data) => {
if data.is_none() {
return State::Err(LockInfo::None);
}
return State::Ok(Guard::new(Some(CowMutCell::Borrowed(Ref::map(
data,
|data| data.as_ref().unwrap(),
)))));
}
Err(_) => return State::Err(LockInfo::Locked),
}
}

pub fn get_mut<'a>(&'a self, id: usize) -> State<Guard<'a, T>> {
match self.data[id].try_borrow_mut() {
Ok(data) => {
if data.is_none() {
Expand Down Expand Up @@ -60,11 +75,11 @@ mod tests {
locked_array.set(2, vec![2]);
locked_array.set(3, vec![3]);

let mut data0 = locked_array.get(0).unwrap();
let mut data0 = locked_array.get_mut(0).unwrap();
assert_eq!(data0.as_slice(), [0, 0]);
data0[0] = 1;
assert_eq!(data0.as_slice(), [1, 0]);
let mut data1 = locked_array.get(1).unwrap();
let mut data1 = locked_array.get_mut(1).unwrap();
assert_eq!(data1.as_slice(), [1]);
data1.push(2);
assert_eq!(data1.as_slice(), [1, 2]);
Expand All @@ -84,11 +99,11 @@ mod tests {
fn test_get_not_set() {
let locked_array = LockedArray::<Vec<i32>>::new();
{
let _d = locked_array.get(1);
assert!(locked_array.get(1).is_err());
let _d = locked_array.get_mut(1);
assert!(locked_array.get_mut(1).is_err());
}
let _ = locked_array.get(1);
assert!(locked_array.get(1).is_err());
let _ = locked_array.get_mut(1);
assert!(locked_array.get_mut(1).is_err());
}

#[cfg(feature = "std")]
Expand All @@ -97,10 +112,10 @@ mod tests {
let locked_array = LockedArray::<Vec<i32>>::new();
locked_array.set(1, vec![10]);
{
let _d = locked_array.get(1);
assert!(locked_array.get(1).is_err());
let _d = locked_array.get_mut(1);
assert!(locked_array.get_mut(1).is_err());
}
let _ = locked_array.get(1);
assert!(locked_array.get(1).is_ok());
let _ = locked_array.get_mut(1);
assert!(locked_array.get_mut(1).is_ok());
}
}
22 changes: 11 additions & 11 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::{Buffer, HasId, OnDropBuffer, Parents, PtrType, Shape, Unit};
/// The `Device` trait is the main trait for all compute devices.
pub trait Device: OnDropBuffer + Sized {
type Base<T: Unit, S: Shape>: HasId + PtrType;
type Data<T: Unit, S: Shape>: HasId + PtrType;
type Data<'a, T: Unit, S: Shape>: HasId + PtrType;

type Error;

Expand All @@ -58,16 +58,16 @@ pub trait Device: OnDropBuffer + Sized {

// add default impl if GAT default go stable
// FIXME: probably a better way to realize these
fn base_to_data<T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S>;
fn wrap_to_data<T: Unit, S: Shape>(
fn base_to_data<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S>;
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S>;
fn data_as_wrap<T: Unit, S: Shape>(data: &Self::Data<T, S>)
-> &Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap_mut<T: Unit, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>>;
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S>;
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(data: &'b Self::Data<'a, T, S>)
-> &'b Self::Wrap<'a, T, Self::Base<T, S>>;
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>>;

/// Creates a new [`Buffer`] using `A`, typically an array type.
///
Expand Down Expand Up @@ -164,7 +164,7 @@ pub trait Retriever<T: Unit, S: Shape = ()>: Device {
#[macro_export]
macro_rules! impl_retriever {
($device:ident, $($trait_bounds:tt)*) => {
impl<T: $( $trait_bounds )* + $crate::Unit, Mods: $crate::Retrieve<Self, T, S>, S: $crate::Shape> $crate::Retriever<T, S> for $device<Mods> {
impl<'a, T: $( $trait_bounds )* + $crate::Unit, Mods: $crate::Retrieve<'a, Self, T, S>, S: $crate::Shape> $crate::Retriever<T, S> for $device<Mods> {
#[inline]
fn retrieve<const NUM_PARENTS: usize>(
&self,
Expand Down
22 changes: 11 additions & 11 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,37 @@ impl<Mods> IsCPU for CPU<Mods> {}
impl<Mods: OnDropBuffer> Device for CPU<Mods> {
type Error = Infallible;
type Base<T: Unit, S: Shape> = CPUPtr<T>;
type Data<T: Unit, S: Shape> = Self::Wrap<T, Self::Base<T, S>>;
type Data<'a, T: Unit, S: Shape> = Self::Wrap<'a, T, Self::Base<T, S>>;
// type WrappedData<T, S: Shape> = ;

fn new() -> Result<Self, Self::Error> {
todo!()
}

#[inline(always)]
fn base_to_data<T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
fn base_to_data<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
self.wrap_in_base(base)
}

#[inline(always)]
fn wrap_to_data<T: Unit, S: Shape>(
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

#[inline(always)]
fn data_as_wrap<T: Unit, S: Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

#[inline(always)]
fn data_as_wrap_mut<T: Unit, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

Expand Down
4 changes: 2 additions & 2 deletions src/devices/cpu/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use crate::{
pass_down_add_operation!(CPU);
pass_down_exec_now!(CPU);

impl<Mods, T, D, S> ApplyFunction<T, S, D> for CPU<Mods>
impl<'a, Mods, T, D, S> ApplyFunction<T, S, D> for CPU<Mods>
where
Mods: Retrieve<Self, T, S> + AddOperation + SetOpHint<T> + 'static,
Mods: Retrieve<'a, Self, T, S> + AddOperation + SetOpHint<T> + 'static,
T: Unit + Copy + Default + ToVal + 'static,
D: Device + 'static,
D::Base<T, S>: Deref<Target = [T]>,
Expand Down
2 changes: 1 addition & 1 deletion src/devices/untyped/untyped_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl OnDropBuffer for Untyped {}
impl<'dev, T: Unit, D: Device, S: Shape> OnNewBuffer<'dev, T, D, S> for Untyped {}

impl WrappedData for Untyped {
type Wrap<T: Unit, Base: HasId + PtrType> = Base;
type Wrap<'a, T: Unit, Base: HasId + PtrType> = Base;

#[inline]
fn wrap_in_base<T: Unit, Base: crate::HasId + crate::PtrType>(
Expand Down
4 changes: 2 additions & 2 deletions src/exec_on_cpu/cl_may_unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub fn cpu_exec_unary_may_unified<'a, T, F, Mods>(
where
T: Unit + Clone + Default + 'static,
F: for<'b> Fn(&'b CachedCPU, &Buffer<'_, T, CachedCPU>) -> Buffer<'b, T, CachedCPU>,
Mods: OnDropBuffer + Retrieve<OpenCL<Mods>, T> + UnifiedMemChain<OpenCL<Mods>> + 'static,
Mods: OnDropBuffer + Retrieve<'a, OpenCL<Mods>, T> + UnifiedMemChain<OpenCL<Mods>> + 'static,
{
let cpu = &device.cpu;
crate::cl_cpu_exec_unified!(device, cpu, x; f(&cpu, &x))
Expand Down Expand Up @@ -61,7 +61,7 @@ pub fn cpu_exec_binary_may_unified<'a, T, F, Mods>(
where
T: Unit + Clone + Default + 'static,
F: for<'b> Fn(&'b CachedCPU, &CpuBuf<'_, T>, &CpuBuf<'_, T>) -> CpuBuf<'b, T>,
Mods: UnifiedMemChain<OpenCL<Mods>> + Retrieve<OpenCL<Mods>, T> + 'static,
Mods: UnifiedMemChain<OpenCL<Mods>> + Retrieve<'a, OpenCL<Mods>, T> + 'static,
{
let cpu = &device.cpu;
crate::cl_cpu_exec_unified!(device, cpu, lhs, rhs; f(&cpu, &lhs, &rhs))
Expand Down
6 changes: 3 additions & 3 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ pub trait Feature: OnDropBuffer {}
// how to fix this:
// add retrieved buffer to no grads pool at the end of the chain (at device level (Retriever trait))
// => "generator", "actor"
pub trait Retrieve<D, T: Unit, S: Shape = ()>: OnDropBuffer {
pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer {
// "generator"
#[track_caller]
unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
&'a self,
device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
) -> crate::Result<Self::Wrap<T, D::Base<T, S>>>
) -> crate::Result<Self::Wrap<'a, T, D::Base<T, S>>>
where
S: Shape,
D: Device + Alloc<T>;
Expand Down
2 changes: 1 addition & 1 deletion src/modules/autograd/gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Gradients {
T: Unit + 'static,
S: Shape,
D: Alloc<T> + ZeroGrad<T> + 'static,
D::Data<T, S>: HasId,
D::Data<'a, T, S>: HasId,
{
self.get_ref(buf.device(), buf.id())
}
Expand Down
18 changes: 9 additions & 9 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ pub struct ReqGradWrapper<Data, T> {
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<T: Unit, Base: crate::HasId + crate::PtrType> = ReqGradWrapper<Mods::Wrap<T, Base>, T>;
type Wrap<'a, T: Unit, Base: crate::HasId + crate::PtrType> = ReqGradWrapper<Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<T: Unit, Base: crate::HasId + crate::PtrType>(
fn wrap_in_base<'a, T: Unit, Base: crate::HasId + crate::PtrType>(
&self,
base: Base,
) -> Self::Wrap<T, Base> {
) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
Expand All @@ -26,16 +26,16 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
}

#[inline]
fn wrapped_as_base<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &Self::Wrap<T, Base>,
) -> &Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &'b Self::Wrap<'a, T, Base>,
) -> &'b Base {
Mods::wrapped_as_base(&wrap.data)
}

#[inline]
fn wrapped_as_base_mut<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
fn wrapped_as_base_mut<'a, 'b, T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &'b mut Self::Wrap<'a, T, Base>,
) -> &'b mut Base {
Mods::wrapped_as_base_mut(&mut wrap.data)
}
}
Expand Down
Loading

0 comments on commit 92e571c

Please sign in to comment.