numba acceleration #14

Open
dsfulf opened this issue Oct 14, 2019 · 2 comments

dsfulf commented Oct 14, 2019

Given that a lot of geostatspy is written in pure Python, I would like to offer the suggestion that some minor refactoring be performed to enable adding numba @njit decorators to compute-intensive functions.

For example, taking the geostatspy.varmapv function, we can split the manipulation of the pandas.DataFrame object from the numerical code:

def varmapv(df,xcol,ycol,vcol,tmin,tmax,nxlag,nylag,dxlag,dylag,minnp,isill): 

    # Parameters - consistent with original GSLIB    
    # df - DataFrame with the spatial data, xcol, ycol, vcol coordinates and property columns
    # tmin, tmax - property trimming limits
    # xlag, xltol - lag distance and lag distance tolerance
    # nlag - number of lags to calculate
    # azm, atol - azimuth and azimuth tolerance
    # bandwh - horizontal bandwidth / maximum distance offset orthogonal to azimuth
    # isill - 1 for standardize sill

    # Load the data

    df_extract = df.loc[(df[vcol] >= tmin) & (df[vcol] <= tmax)]    # trim values outside tmin and tmax
    nd = len(df_extract)
    x = df_extract[xcol].values
    y = df_extract[ycol].values
    vr = df_extract[vcol].values
    
    # Summary statistics for the data after trimming
   ...

After refactoring:

from numba import njit

def varmapv(df, xcol, ycol, vcol, tmin, tmax, nxlag, nylag, dxlag, dylag, minnp, isill): 

    # Parameters - consistent with original GSLIB    
    # df - DataFrame with the spatial data, xcol, ycol, vcol coordinates and property columns
    # tmin, tmax - property trimming limits
    # xlag, xltol - lag distance and lag distance tolerance
    # nlag - number of lags to calculate
    # azm, atol - azimuth and azimuth tolerance
    # bandwh - horizontal bandwidth / maximum distance offset orthogonal to azimuth
    # isill - 1 for standardize sill

    # Load the data
    df_extract = df.loc[(df[vcol] >= tmin) & (df[vcol] <= tmax)]    # trim values outside tmin and tmax
    nd = len(df_extract)
    x = df_extract[xcol].values
    y = df_extract[ycol].values
    vr = df_extract[vcol].values
    
    return _varmapv(nd, x, y, vr, nxlag, nylag, dxlag, dylag, minnp, isill)

@njit
def _varmapv(nd, x, y, vr, nxlag, nylag, dxlag, dylag, minnp, isill):

    
    # Summary statistics for the data after trimming
    ...
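The same wrapper/kernel split can be sketched end to end with a much smaller function. This is only an illustration under stated assumptions: `semivar_within` and `_semivar_within` are hypothetical names, the kernel is a simple omnidirectional semivariogram over pairs within `max_dist` rather than the real varmapv logic, and the `njit` fallback exists only so the sketch still runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python without numba
    def njit(func):
        return func


def semivar_within(df, xcol, ycol, vcol, tmin, tmax, max_dist):
    """Wrapper (hypothetical): column handling stays in ordinary Python.

    `df` may be a pandas DataFrame or any mapping of column name to values.
    """
    v_all = np.asarray(df[vcol], dtype=np.float64)
    keep = (v_all >= tmin) & (v_all <= tmax)  # trim values outside tmin and tmax
    x = np.asarray(df[xcol], dtype=np.float64)[keep]
    y = np.asarray(df[ycol], dtype=np.float64)[keep]
    vr = v_all[keep]
    return _semivar_within(x, y, vr, max_dist)


@njit
def _semivar_within(x, y, vr, max_dist):
    """Kernel (hypothetical): only arrays and scalars, so numba can compile it."""
    total = 0.0
    npairs = 0
    n = x.shape[0]
    for i in range(n):
        for j in range(i + 1, n):
            dx = x[i] - x[j]
            dy = y[i] - y[j]
            if dx * dx + dy * dy <= max_dist * max_dist:
                diff = vr[i] - vr[j]
                total += diff * diff
                npairs += 1
    if npairs == 0:
        return np.nan, 0
    # Semivariogram: half the mean squared difference over the accepted pairs
    return total / (2.0 * npairs), npairs
```

The key design point is that the decorated function never sees the DataFrame: it receives plain float arrays and scalars, which is what numba's nopython mode can type and compile.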

The current implementation takes 580 ms on my machine, while the numba-decorated version takes 2.02 ms.

For scaling up to several thousand data points, a factor of over 100x is considerable!
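When reproducing a comparison like this, keep in mind that the first call to an @njit function includes JIT compilation, so it is usually fairer to make one warm-up call before timing. A minimal sketch, assuming a hypothetical stand-in kernel `pairwise_sq_sum` (not varmapv) and random data; the fallback only keeps the script runnable without numba, in which case both timings measure the same plain Python function:

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python without numba
    def njit(func):
        return func


def pairwise_sq_sum(vr):
    """Hypothetical kernel: sum of squared differences over all pairs."""
    total = 0.0
    n = vr.shape[0]
    for i in range(n):
        for j in range(i + 1, n):
            diff = vr[i] - vr[j]
            total += diff * diff
    return total


# Compile a second version of the same function; the original stays pure Python
pairwise_sq_sum_fast = njit(pairwise_sq_sum)

vr = np.random.default_rng(0).normal(size=300)

# Warm-up call so compilation time is not counted in the measurement
pairwise_sq_sum_fast(vr)

start = time.perf_counter()
slow_result = pairwise_sq_sum(vr)
slow_time = time.perf_counter() - start

start = time.perf_counter()
fast_result = pairwise_sq_sum_fast(vr)
fast_time = time.perf_counter() - start

print(f"pure Python: {slow_time * 1e3:.3f} ms, compiled: {fast_time * 1e3:.3f} ms")
```

The measured ratio depends on the machine, the data size, and the kernel, so the numbers from this sketch should not be expected to match the figures quoted above.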

Further optimization is possible for functions that are parallelizable: numba can release the GIL (nogil=True) so compiled code runs concurrently in threads, or parallelize loops across threads automatically (parallel=True).
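As a rough sketch of what that could look like, numba's `parallel=True` option together with `prange` spreads the iterations of an outer loop across threads. The kernel below is a hypothetical example (not taken from geostatspy): each iteration writes only to its own output slot, so no two threads touch the same element. The fallback definitions are an assumption added so the code still runs serially without numba.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # assumption: serial fallback without numba
    prange = range

    def njit(*args, **kwargs):
        # Support both @njit and @njit(parallel=True) usage
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def mean_sq_diff_per_point(vr):
    """Hypothetical kernel: mean squared difference of each value to all values."""
    n = vr.shape[0]
    out = np.zeros(n)
    for i in prange(n):  # iterations are independent, so they can run in parallel
        acc = 0.0
        for j in range(n):
            diff = vr[i] - vr[j]
            acc += diff * diff
        out[i] = acc / n
    return out
```

Loops that accumulate into shared arrays (for example, binning pairs into a lag grid) need more care, since several threads could update the same cell; a per-row or per-point output like this one avoids that problem.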

@GeostatsGuy
Owner

Thank you dsfulf. Now that I have some more time (last term was crazy), I'll ask my graduate students to implement this.

Thank you for contributing to geostatspy, Michael

@PauloCarvalhoRJ

It is certainly worth it. I numbafied parts of my Python code and the speedup is indeed on the order of hundreds.
