Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
owang22 committed May 27, 2024
1 parent d14f25f commit 855d904
Show file tree
Hide file tree
Showing 80 changed files with 264,215 additions and 141 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
checkpoints/
results/
4 changes: 4 additions & 0 deletions =1.2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Collecting causal-conv1d
Downloading causal_conv1d-1.2.2.post1.tar.gz (7.2 kB)
Preparing metadata (setup.py): started
Preparing metadata (setup.py): finished with status 'error'
237 changes: 237 additions & 0 deletions TimeLLMcolab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from math import sqrt

import torch
import torch.nn as nn

from transformers import MambaModel, AutoTokenizer,MambaConfig
from layers.Embed import PatchEmbedding
import transformers
from layers.StandardNorm import Normalize

transformers.logging.set_verbosity_error()


class FlattenHead(nn.Module):
def __init__(self, n_vars, nf, target_window, head_dropout=0):
super().__init__()
self.n_vars = n_vars
self.flatten = nn.Flatten(start_dim=-2)
self.linear = nn.Linear(nf, target_window)
self.dropout = nn.Dropout(head_dropout)

def forward(self, x):
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)
return x


class Model(nn.Module):

def __init__(self, configs, patch_len=16, stride=8):
super(Model, self).__init__()
self.task_name = configs.task_name
self.pred_len = configs.pred_len
self.seq_len = configs.seq_len
self.d_ff = configs.d_ff
self.top_k = 5
self.d_llm = configs.llm_dim
self.patch_len = configs.patch_len
self.stride = configs.stride


self.tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
self.mamba_config = MambaConfig.from_pretrained("state-spaces/mamba-130m-hf")

self.mamba_config.num_hidden_layers = configs.llm_layers
self.mamba_config.output_attentions = True
self.mamba_config.output_hidden_states = True


self.llm_model = MambaModel.from_pretrained(
"state-spaces/mamba-130m-hf",
config=self.mamba_config
)


if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
pad_token = '[PAD]'
self.tokenizer.add_special_tokens({'pad_token': pad_token})
self.tokenizer.pad_token = pad_token

for param in self.llm_model.parameters():
param.requires_grad = False

if configs.prompt_domain:
self.description = configs.content
else:
self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'

self.dropout = nn.Dropout(configs.dropout)

self.patch_embedding = PatchEmbedding(
configs.d_model, self.patch_len, self.stride, configs.dropout)

self.word_embeddings = self.llm_model.get_input_embeddings().weight
self.vocab_size = self.word_embeddings.shape[0]
self.num_tokens = 1000
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

print("d_model: ", configs.d_model)
print("n_heads: ", configs.n_heads)
print("d_ff: ", self.d_ff)
print("d_llm: ", self.d_llm)

self.reprogramming_layer = ReprogrammingLayer(configs.d_model, configs.n_heads, self.d_ff, self.d_llm)

self.patch_nums = int((configs.seq_len - self.patch_len) / self.stride + 2)

self.head_nf = self.d_ff * self.patch_nums
#print("self head_nf type: ", type(self.head_nf))
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
self.output_projection = FlattenHead(configs.enc_in, self.head_nf, self.pred_len,
head_dropout=configs.dropout)
else:
raise NotImplementedError

self.normalize_layers = Normalize(configs.enc_in, affine=False)

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :]
return None

def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):

x_enc = self.normalize_layers(x_enc, 'norm')

B, T, N = x_enc.size()
x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

min_values = torch.min(x_enc, dim=1)[0]
max_values = torch.max(x_enc, dim=1)[0]
medians = torch.median(x_enc, dim=1).values
lags = self.calcute_lags(x_enc)
trends = x_enc.diff(dim=1).sum(dim=1)

prompt = []
for b in range(x_enc.shape[0]):
min_values_str = str(min_values[b].tolist()[0])
max_values_str = str(max_values[b].tolist()[0])
median_values_str = str(medians[b].tolist()[0])
lags_values_str = str(lags[b].tolist())
prompt_ = (
f"<|start_prompt|>Dataset description: {self.description}"
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
"Input statistics: "
f"min value {min_values_str}, "
f"max value {max_values_str}, "
f"median value {median_values_str}, "
f"the trend of input is {'upward' if trends[b] > 0 else 'downward'}, "
f"top 5 lags are : {lags_values_str}<|<end_prompt>|>"
)

prompt.append(prompt_)

x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)

source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

x_enc = x_enc.permute(0, 2, 1).contiguous()
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))

#print("enc out shape: ", enc_out.shape)
#print("source_embeddings input shape: ", source_embeddings.shape)

enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)

modelOut = self.llm_model(inputs_embeds=llama_enc_out)

#print("output Model: ", modelOut)

dec_out = (modelOut.last_hidden_state)
dec_out = dec_out[:, :, :self.d_ff]

dec_out = torch.reshape(
dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

dec_out = dec_out.to(torch.bfloat16)

dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])
dec_out = dec_out.permute(0, 2, 1).contiguous()

dec_out = self.normalize_layers(dec_out, 'denorm')

return dec_out

def calcute_lags(self, x_enc):
q_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
k_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
res = q_fft * torch.conj(k_fft)
corr = torch.fft.irfft(res, dim=-1)
mean_value = torch.mean(corr, dim=1)
_, lags = torch.topk(mean_value, self.top_k, dim=-1)
return lags


class ReprogrammingLayer(nn.Module):
def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
super(ReprogrammingLayer, self).__init__()

d_keys = d_keys or (d_model // n_heads)

self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
self.n_heads = n_heads
self.dropout = nn.Dropout(attention_dropout)



def forward(self, target_embedding, source_embedding, value_embedding):
B, L, _ = target_embedding.shape
S, _ = source_embedding.shape
H = self.n_heads

'''
print("target_embedding shape: ", target_embedding.shape)
print("source_embed shape: ", source_embedding.shape)
print("value_embed shape:", value_embedding.shape)
print("d_model ", d_model)
print("d_keys ", d_keys)
print("d_llm ", d_llm)
print("n_heads ", n_heads)
'''


target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
source_embedding = self.key_projection(source_embedding).view(S, H, -1)
value_embedding = self.value_projection(value_embedding).view(S, H, -1)

out = self.reprogramming(target_embedding, source_embedding, value_embedding)

out = out.reshape(B, L, -1)

return self.out_projection(out)

def reprogramming(self, target_embedding, source_embedding, value_embedding):
B, L, H, E = target_embedding.shape

scale = 1. / sqrt(E)

scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

A = self.dropout(torch.softmax(scale * scores, dim=-1))
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

return reprogramming_embedding
Binary file added data_provider/__pycache__/__init__.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added data_provider/__pycache__/m4.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 855d904

Please sign in to comment.