Added quantile regression as a possible final MLP layer.
Paul Masurel committed Jul 6, 2015
1 parent a303ec1 commit caef584
Showing 2 changed files with 79 additions and 0 deletions.
43 changes: 43 additions & 0 deletions pylearn2/models/mlp.py
@@ -3824,6 +3824,49 @@ def L1WeightDecay(*args, **kwargs):
    return _L1WD(*args, **kwargs)


class QuantileRegression(Linear):
    """
    A linear layer for quantile regression.

    A QuantileRegression layer
    (http://en.wikipedia.org/wiki/Quantile_regression) is a Linear layer
    trained with a cost that makes its output an estimator of a specific
    percentile (quantile) of the conditional distribution of the target,
    rather than of its mean.

    Parameters
    ----------
    layer_name : str
        The layer name.
    percentile : float (0 < percentile < 1)
        Percentile being estimated.
    kwargs : dict
        Keyword arguments passed through to the Linear constructor.
    """

    def __init__(self,
                 layer_name,
                 percentile=0.2,
                 **kwargs):
        # The output dimension is fixed to 1: a single quantile estimate.
        Linear.__init__(self, 1, layer_name, **kwargs)
        self.percentile = percentile

    @wraps(Layer.get_layer_monitoring_channels)
    def get_layer_monitoring_channels(self,
                                      state_below=None,
                                      state=None,
                                      targets=None):
        rval = Linear.get_layer_monitoring_channels(
            self,
            state_below,
            state,
            targets)
        assert isinstance(rval, OrderedDict)
        if targets is not None:
            # Quantile ("pinball") loss, averaged over the batch.
            rval['qcost'] = (T.abs_(targets - state) *
                             (0.5 + (self.percentile - 0.5) *
                              T.sgn(targets - state))).mean()
        return rval

    @wraps(Layer.cost_matrix)
    def cost_matrix(self, Y, Y_hat):
        # Pinball loss: penalizes under-estimation with weight `percentile`
        # and over-estimation with weight `1 - percentile`.
        return (T.abs_(Y - Y_hat) *
                (0.5 + (self.percentile - 0.5) * T.sgn(Y - Y_hat)))


class LinearGaussian(Linear):

"""
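For intuition about the cost added above: it is the pinball (quantile) loss, |y - y_hat| * (0.5 + (p - 0.5) * sgn(y - y_hat)), which charges a weight of p for under-estimation and 1 - p for over-estimation, so the prediction minimizing its expectation is the p-th quantile of the target. The following standalone NumPy sketch is not part of the commit (the name pinball_loss is illustrative); it checks numerically that the constant prediction minimizing the mean loss over a Uniform(0, 1) sample sits near the requested quantile.

import numpy as np

def pinball_loss(y, y_hat, percentile):
    """Same expression as QuantileRegression.cost_matrix, written with NumPy."""
    diff = y - y_hat
    return np.abs(diff) * (0.5 + (percentile - 0.5) * np.sign(diff))

rng = np.random.RandomState(0)
sample = rng.rand(100000)            # Uniform(0, 1) targets
percentile = 0.65
# Grid-search the constant prediction that minimizes the mean pinball loss.
candidates = np.linspace(0., 1., 1001)
mean_losses = [pinball_loss(sample, c, percentile).mean() for c in candidates]
best = candidates[np.argmin(mean_losses)]
# Up to sampling noise, the minimizer is the 0.65 quantile of the sample,
# i.e. approximately 0.65 for Uniform(0, 1) data.
print(best)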
36 changes: 36 additions & 0 deletions pylearn2/models/tests/test_mlp.py
@@ -1389,3 +1389,39 @@ def test_pooling_with_anon_variable():
                       image_shape=im_shp, try_dnn=False)
    pool_1 = mean_pool(X_sym, pool_shape=shp, pool_stride=strd,
                       image_shape=im_shp)



def test_quantile_regression():
    """
    Train a single-layer QuantileRegression MLP on noisy linear data for
    several percentiles, and check that the learned weights recover the true
    coefficients and that the learned bias matches the requested percentile.
    """
    np.random.seed(2)
    nb_rows = 1000
    X = np.random.normal(size=(nb_rows, 2)).astype(theano.config.floatX)
    # Uniform(0, 1) noise: its p-th quantile is p, so the learned bias should
    # converge to the requested percentile.
    noise = np.random.rand(nb_rows, 1)
    coeffs = np.array([[3.], [4.]])
    y_0 = np.dot(X, coeffs)
    y = y_0 + noise
    dataset = DenseDesignMatrix(X=X, y=y)
    for percentile in [0.22, 0.5, 0.65]:
        mlp = MLP(
            nvis=2,
            layers=[
                QuantileRegression('quantile_regression_layer',
                                   init_bias=0.0,
                                   percentile=percentile,
                                   irange=0.1)
            ]
        )
        train = Train(dataset, mlp, SGD(0.05, batch_size=100))
        train.algorithm.termination_criterion = EpochCounter(100)
        train.main_loop()
        # Smoke test: predictions can be computed through fprop.
        inputs = mlp.get_input_space().make_theano_batch()
        outputs = mlp.fprop(inputs)
        y_ = theano.function([inputs], outputs, allow_input_downcast=True)(X)
        layer = mlp.layers[0]
        assert np.allclose(layer.get_weights(), coeffs, rtol=0.05)
        assert np.allclose(layer.get_biases(), np.array(percentile), rtol=0.05)

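Not part of the commit: a minimal usage sketch showing how two QuantileRegression layers trained at a low and a high percentile can bracket the target with a prediction interval. It reuses only classes that appear in the commit and its test (MLP, QuantileRegression, SGD, Train, EpochCounter, DenseDesignMatrix); the data set-up and layer names are illustrative.

import numpy as np
import theano
from pylearn2.models.mlp import MLP, QuantileRegression
from pylearn2.training_algorithms.sgd import SGD
from pylearn2.termination_criteria import EpochCounter
from pylearn2.train import Train
from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix

# Illustrative data: linear signal plus Uniform(0, 1) noise.
X = np.random.normal(size=(1000, 2)).astype(theano.config.floatX)
y = np.dot(X, np.array([[3.], [4.]])) + np.random.rand(1000, 1)
dataset = DenseDesignMatrix(X=X, y=y)

predictors = {}
for percentile in (0.1, 0.9):        # lower and upper bounds of the interval
    mlp = MLP(nvis=2,
              layers=[QuantileRegression('q_%d' % int(100 * percentile),
                                         percentile=percentile,
                                         irange=0.1)])
    train = Train(dataset, mlp, SGD(0.05, batch_size=100))
    train.algorithm.termination_criterion = EpochCounter(100)
    train.main_loop()
    inputs = mlp.get_input_space().make_theano_batch()
    predictors[percentile] = theano.function([inputs], mlp.fprop(inputs),
                                             allow_input_downcast=True)

lower, upper = predictors[0.1](X), predictors[0.9](X)
# Roughly 80% of the targets should fall inside [lower, upper].
print(np.mean((lower <= y) & (y <= upper)))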