Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated ae1svm comments with parameter descriptions #595

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 175 additions & 3 deletions pyod/models/ae1svm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""Using AE-1SVM with Outlier Detection (PyTorch)
Source: https://arxiv.org/pdf/1804.04888
There is another implementation of this model by Minh Nghia: https://github.com/minh-nghia/AE-1SVM (Tensorflow)
"""
# Author: Zhuo Xiao <[email protected]>


import numpy as np
import torch
from sklearn.utils import check_array
Expand All @@ -15,22 +15,56 @@
from ..utils.stat_models import pairwise_distances_no_broadcast
from ..utils.torch_utility import get_activation_by_name, TorchDataset


class InnerAE1SVM(nn.Module):
"""Internal model combining an Autoencoder and One-class SVM.

Parameters
----------
n_features : int
Number of features in the input data.

encoding_dim : int
Dimension of the encoded representation.

rff_dim : int
Dimension of the random Fourier features.

sigma : float, optional (default=1.0)
Scaling factor for the random Fourier features.

hidden_neurons : tuple of int, optional (default=(128, 64))
Number of neurons in the hidden layers.

dropout_rate : float, optional (default=0.2)
Dropout rate for regularization.

batch_norm : bool, optional (default=True)
Whether to use batch normalization.

hidden_activation : str, optional (default='relu')
Activation function for hidden layers.
"""

def __init__(self, n_features, encoding_dim, rff_dim, sigma=1.0,
hidden_neurons=(128, 64),
dropout_rate=0.2, batch_norm=True, hidden_activation='relu'):
super(InnerAE1SVM, self).__init__()

# Encoder: Sequential model consisting of linear, batch norm, activation, and dropout layers.
self.encoder = nn.Sequential()
# Decoder: Sequential model to reconstruct the input from the encoded representation.
self.decoder = nn.Sequential()
# Random Fourier Features layer for approximating the kernel function.
self.rff = RandomFourierFeatures(encoding_dim, rff_dim, sigma)
# Parameters for the SVM.
self.svm_weights = nn.Parameter(torch.randn(rff_dim))
self.svm_bias = nn.Parameter(torch.randn(1))

# Activation function
activation = get_activation_by_name(hidden_activation)
layers_neurons_encoder = [n_features, *hidden_neurons, encoding_dim]

# Build encoder
for idx in range(len(layers_neurons_encoder) - 1):
self.encoder.add_module(f"linear{idx}",
nn.Linear(layers_neurons_encoder[idx],
Expand All @@ -43,6 +77,7 @@ def __init__(self, n_features, encoding_dim, rff_dim, sigma=1.0,

layers_neurons_decoder = layers_neurons_encoder[::-1]

# Build decoder
for idx in range(len(layers_neurons_decoder) - 1):
self.decoder.add_module(f"linear{idx}",
nn.Linear(layers_neurons_decoder[idx],
Expand All @@ -56,28 +91,128 @@ def __init__(self, n_features, encoding_dim, rff_dim, sigma=1.0,
nn.Dropout(dropout_rate))

def forward(self, x):
    """Run the autoencoder and the kernel-feature map in one pass.

    Parameters
    ----------
    x : torch.Tensor
        Input batch.

    Returns
    -------
    tuple of torch.Tensor
        The decoder's reconstruction of ``x`` and the random Fourier
        features computed from the encoded representation.
    """
    # Encode once; both the decoder and the RFF layer consume the code.
    latent = self.encoder(x)
    fourier_features = self.rff(latent)
    reconstruction = self.decoder(latent)
    return reconstruction, fourier_features

def svm_decision_function(self, rff_features):
    """Evaluate the linear one-class SVM in the RFF space.

    Parameters
    ----------
    rff_features : torch.Tensor
        Random Fourier features of the encoded samples.

    Returns
    -------
    torch.Tensor
        Decision scores, one per sample.
    """
    # Linear score: <w, phi(x)> + b, computed with the matmul operator.
    scores = rff_features @ self.svm_weights
    return scores + self.svm_bias


class RandomFourierFeatures(nn.Module):
    """Random Fourier feature map for kernel approximation.

    Projects inputs through a fixed random linear map and applies a
    cosine, approximating a shift-invariant kernel (Rahimi & Recht).

    Parameters
    ----------
    input_dim : int
        Dimension of the input data.

    output_dim : int
        Dimension of the output features.

    sigma : float, optional (default=1.0)
        Scaling factor for the random Fourier features.
    """

    def __init__(self, input_dim, output_dim, sigma=1.0):
        super(RandomFourierFeatures, self).__init__()
        # Draw the projection matrix first, then the phase offsets, so the
        # RNG consumption order matches the reference implementation.
        projection = torch.randn(input_dim, output_dim) * sigma
        phase = torch.randn(output_dim) * 2 * np.pi
        self.weights = nn.Parameter(projection)
        self.bias = nn.Parameter(phase)

    def forward(self, x):
        """Map inputs to their random Fourier features.

        Parameters
        ----------
        x : torch.Tensor
            Input data.

        Returns
        -------
        torch.Tensor
            Random Fourier features, ``cos(x @ W + b)``.
        """
        projected = torch.matmul(x, self.weights) + self.bias
        return torch.cos(projected)


class AE1SVM(BaseDetector):
"""Auto Encoder with One-class SVM for anomaly detection."""
"""Auto Encoder with One-class SVM for anomaly detection.

Note: self.device is needed or all tensors may not be on the same device (if device w/ GPU running)

Parameters
----------
hidden_neurons : list, optional (default=[64, 32])
Number of neurons in each hidden layer.

hidden_activation : str, optional (default='relu')
Activation function for the hidden layers.

batch_norm : bool, optional (default=True)
Whether to use batch normalization.

learning_rate : float, optional (default=1e-3)
Learning rate for training the model.

epochs : int, optional (default=50)
Number of training epochs.

batch_size : int, optional (default=32)
Size of each training batch.

dropout_rate : float, optional (default=0.2)
Dropout rate for regularization.

weight_decay : float, optional (default=1e-5)
Weight decay (L2 penalty) for the optimizer.

preprocessing : bool, optional (default=True)
Whether to apply standard scaling to the input data.

loss_fn : callable, optional (default=torch.nn.MSELoss)
Loss function to use for reconstruction loss.

contamination : float, optional (default=0.1)
Proportion of outliers in the data.

alpha : float, optional (default=1.0)
Weight for the reconstruction loss in the final loss computation.

sigma : float, optional (default=1.0)
Scaling factor for the random Fourier features.

nu : float, optional (default=0.1)
Parameter for the SVM loss.

kernel_approx_features : int, optional (default=1000)
Number of random Fourier features to approximate the kernel.
"""

def __init__(self, hidden_neurons=None, hidden_activation='relu',
batch_norm=True, learning_rate=1e-3, epochs=50, batch_size=32,
Expand Down Expand Up @@ -109,6 +244,21 @@ def __init__(self, hidden_neurons=None, hidden_activation='relu',
self.kernel_approx_features = kernel_approx_features

def fit(self, X, y=None):
"""Fit the model to the data.

Parameters
----------
X : numpy.ndarray
Input data.

y : None
Ignored, present for API consistency by convention.

Returns
-------
self : object
Fitted estimator.
"""
X = check_array(X)
self._set_n_classes(y)

Expand Down Expand Up @@ -139,11 +289,21 @@ def fit(self, X, y=None):
else:
raise ValueError('Training failed, no valid model state found')

if not isinstance(X, np.ndarray):
X = np.array(X)

self.decision_scores_ = self.decision_function(X)
self._process_decision_scores()
return self

def _train_autoencoder(self, train_loader):
"""Train the autoencoder.

Parameters
----------
train_loader : torch.utils.data.DataLoader
DataLoader for the training data.
"""
optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
Expand Down Expand Up @@ -173,6 +333,18 @@ def _train_autoencoder(self, train_loader):
self.best_model_dict = self.model.state_dict()

def decision_function(self, X):
"""Predict raw anomaly score of X using the fitted detector.

Parameters
----------
X : numpy.ndarray
The input samples.

Returns
-------
numpy.ndarray
The anomaly score of the input samples.
"""
check_is_fitted(self, ['model', 'best_model_dict'])
X = check_array(X)
dataset = TorchDataset(X=X, mean=self.mean,
Expand Down
Loading