Showing 80 changed files with 264,215 additions and 141 deletions.
@@ -0,0 +1,2 @@
checkpoints/
results/
@@ -0,0 +1,4 @@
Collecting causal-conv1d
  Downloading causal_conv1d-1.2.2.post1.tar.gz (7.2 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'error'
@@ -0,0 +1,237 @@
from math import sqrt

import torch
import torch.nn as nn

import transformers
from transformers import MambaModel, AutoTokenizer, MambaConfig

from layers.Embed import PatchEmbedding
from layers.StandardNorm import Normalize

transformers.logging.set_verbosity_error()

class FlattenHead(nn.Module):
    # Flattens the (feature, patch) dimensions and projects them to the prediction horizon.
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

class Model(nn.Module):

    def __init__(self, configs, patch_len=16, stride=8):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.d_ff = configs.d_ff
        self.top_k = 5
        self.d_llm = configs.llm_dim
        self.patch_len = configs.patch_len
        self.stride = configs.stride

        # Backbone: pretrained Mamba used as a frozen sequence model.
        self.tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
        self.mamba_config = MambaConfig.from_pretrained("state-spaces/mamba-130m-hf")

        self.mamba_config.num_hidden_layers = configs.llm_layers
        self.mamba_config.output_attentions = True
        self.mamba_config.output_hidden_states = True

        self.llm_model = MambaModel.from_pretrained(
            "state-spaces/mamba-130m-hf",
            config=self.mamba_config
        )

        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.tokenizer.add_special_tokens({'pad_token': pad_token})
            self.tokenizer.pad_token = pad_token

        # Freeze the backbone; only the patching, reprogramming, and projection layers are trained.
        for param in self.llm_model.parameters():
            param.requires_grad = False

        if configs.prompt_domain:
            self.description = configs.content
        else:
            self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'

        self.dropout = nn.Dropout(configs.dropout)

        self.patch_embedding = PatchEmbedding(
            configs.d_model, self.patch_len, self.stride, configs.dropout)

        # Map the full vocabulary embedding matrix down to a small set of text prototypes.
        self.word_embeddings = self.llm_model.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1000
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        print("d_model: ", configs.d_model)
        print("n_heads: ", configs.n_heads)
        print("d_ff: ", self.d_ff)
        print("d_llm: ", self.d_llm)

        self.reprogramming_layer = ReprogrammingLayer(configs.d_model, configs.n_heads, self.d_ff, self.d_llm)

        self.patch_nums = int((configs.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.output_projection = FlattenHead(configs.enc_in, self.head_nf, self.pred_len,
                                                 head_dropout=configs.dropout)
        else:
            raise NotImplementedError

        self.normalize_layers = Normalize(configs.enc_in, affine=False)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        # Per-series statistics used to build the text prompt.
        min_values = torch.min(x_enc, dim=1)[0]
        max_values = torch.max(x_enc, dim=1)[0]
        medians = torch.median(x_enc, dim=1).values
        lags = self.calculate_lags(x_enc)
        trends = x_enc.diff(dim=1).sum(dim=1)

        prompt = []
        for b in range(x_enc.shape[0]):
            min_values_str = str(min_values[b].tolist()[0])
            max_values_str = str(max_values[b].tolist()[0])
            median_values_str = str(medians[b].tolist()[0])
            lags_values_str = str(lags[b].tolist())
            prompt_ = (
                f"<|start_prompt|>Dataset description: {self.description} "
                f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
                "Input statistics: "
                f"min value {min_values_str}, "
                f"max value {max_values_str}, "
                f"median value {median_values_str}, "
                f"the trend of input is {'upward' if trends[b] > 0 else 'downward'}, "
                f"top 5 lags are : {lags_values_str}<|end_prompt|>"
            )
            prompt.append(prompt_)

        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

        # Embed the prompt with the frozen backbone's input embeddings.
        prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device))  # (batch, prompt_token, dim)

        # Text prototypes: project the vocabulary embeddings down to num_tokens rows.
        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        # Patch the series and reprogram the patches into the backbone's embedding space.
        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)

        # Prepend the prompt embeddings and run the frozen Mamba backbone.
        mamba_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
        model_out = self.llm_model(inputs_embeds=mamba_enc_out)

        dec_out = model_out.last_hidden_state
        dec_out = dec_out[:, :, :self.d_ff]

        dec_out = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        dec_out = dec_out.to(torch.bfloat16)

        # Keep the last patch_nums positions and project to the prediction horizon.
        dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])
        dec_out = dec_out.permute(0, 2, 1).contiguous()

        dec_out = self.normalize_layers(dec_out, 'denorm')

        return dec_out

    def calculate_lags(self, x_enc):
        # Autocorrelation via FFT; the top-k peaks give the dominant lags.
        q_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags

class ReprogrammingLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        # target_embedding: (B, L, d_model) patch embeddings
        # source_embedding / value_embedding: (S, d_llm) text prototypes
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)

        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        # Multi-head cross-attention from patch queries to prototype keys/values.
        B, L, H, E = target_embedding.shape

        scale = 1. / sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
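
For reference, a minimal smoke-test sketch of the model defined above. Everything in it is an illustrative assumption rather than part of this commit: the import path, the config values, and the dummy tensor shapes. The module is cast to bfloat16 so that the explicit bfloat16 casts inside forecast() match the layer weights; the first run downloads state-spaces/mamba-130m-hf from the Hub, and without the causal-conv1d/mamba-ssm kernels the Hugging Face MambaModel should fall back to its slower pure-PyTorch path.

# Hypothetical usage sketch -- config values and shapes are assumptions for illustration.
from types import SimpleNamespace

import torch

from models.MambaTimeLLM import Model  # hypothetical path; the file name is not shown in this diff

configs = SimpleNamespace(
    task_name='long_term_forecast',
    seq_len=96, pred_len=96, enc_in=7,   # 7 variates, as in the ETT datasets
    d_model=16, n_heads=8, d_ff=32,
    llm_dim=768, llm_layers=6,           # mamba-130m-hf hidden size is 768
    patch_len=16, stride=8,
    dropout=0.1, prompt_domain=False, content='',
)

# Cast the whole module to bfloat16 so the patch-embedding and reprogramming weights
# match the .to(torch.bfloat16) casts inside forecast().
model = Model(configs).to(torch.bfloat16).eval()

B = 4
x_enc = torch.randn(B, configs.seq_len, configs.enc_in)
x_mark_enc = torch.randn(B, configs.seq_len, 4)   # time features; unused by forecast()
x_dec = torch.randn(B, configs.pred_len, configs.enc_in)
x_mark_dec = torch.randn(B, configs.pred_len, 4)

with torch.no_grad():
    out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)

print(out.shape)  # expected: torch.Size([4, 96, 7]) == (B, pred_len, enc_in)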
7 binary files not shown.