diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 78cdf89a..c65d21a9 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -825,11 +825,11 @@ defmodule Scholar.Metrics.Regression do y_quantile = Nx.broadcast(quantile(y_true, alpha), shape) denominator = mean_pinball_loss(y_true, y_quantile, alpha: alpha, multioutput: :raw_values) - nonzero_numerator = Nx.not_equal(numerator, 0) - nonzero_denominator = Nx.not_equal(denominator, 0) + nonzero_numerator = numerator != 0 + nonzero_denominator = denominator != 0 - valid_score = Nx.logical_and(nonzero_numerator, nonzero_denominator) - invalid_score = Nx.logical_and(nonzero_numerator, Nx.logical_not(nonzero_denominator)) + valid_score = nonzero_numerator and nonzero_denominator + invalid_score = nonzero_numerator and not nonzero_denominator output_scores = Nx.broadcast(1, {m})