Skip to content

Commit

Permalink
fix breaking dnn input formatting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Jan 1, 2025
1 parent aec881f commit 86e57ff
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,17 @@ def zscore_trial_segs(
# if train is not jagged, it gets converted completely to object
# np.ndarray. Hence, cannot exclusively use normed_train.loc
if isinstance(normed_train, pd.DataFrame):
normed_train = normed_train.loc
normed_train[:, train_nan_cols] = 0
normed_train.loc[:, train_nan_cols] = 0
else:
normed_train[:, train_nan_cols] = 0
else:
normed_train = np.empty_like(train)
for i, nsvstseg in enumerate(train):
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_train[i] = zscored

normed_rest_feats = []
Expand All @@ -230,16 +232,18 @@ def zscore_trial_segs(
if is_2D:
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
if isinstance(normed_feats, pd.DataFrame):
normed_feats = normed_feats.loc
normed_feats[:, train_nan_cols] = 0
normed_feats.loc[:, train_nan_cols] = 0
else:
normed_feats[:, train_nan_cols] = 0
normed_rest_feats.append(normed_feats)
else:
normed_feats = np.empty_like(feats)
for i, trialSegROI in enumerate(feats):
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_feats[i] = zscored
normed_rest_feats.append(normed_feats)

Expand Down

0 comments on commit 86e57ff

Please sign in to comment.