-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ddff739
Showing
18 changed files
with
2,303 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import graphviz as gv | ||
from IPython.display import Image | ||
from IPython.display import display | ||
import functools | ||
from copy import deepcopy, copy | ||
import itertools | ||
import Lstar | ||
from random import randint, shuffle | ||
import random | ||
from time import clock | ||
import string | ||
|
||
digraph = functools.partial(gv.Digraph, format='png') | ||
graph = functools.partial(gv.Graph, format='png') | ||
|
||
separator = "_" | ||
|
||
class DFA: | ||
def __init__(self,obs_table): | ||
self.alphabet = obs_table.A #alphabet | ||
self.Q = [s for s in obs_table.S if s==obs_table.minimum_matching_row(s)] #avoid duplicate states | ||
self.q0 = obs_table.minimum_matching_row("") | ||
self.F = [s for s in self.Q if obs_table.T[s]== 1] | ||
self._make_transition_function(obs_table) | ||
|
||
def _make_transition_function(self,obs_table): | ||
self.delta = {} | ||
for s in self.Q: | ||
self.delta[s] = {} | ||
for a in self.alphabet: | ||
self.delta[s][a] = obs_table.minimum_matching_row(s+a) | ||
|
||
def classify_word(self,word): | ||
#assumes word is string with only letters in alphabet | ||
q = self.q0 | ||
for a in word: | ||
q = self.delta[q][a] | ||
return q in self.F | ||
|
||
def draw_nicely(self,force=False,maximum=60): #todo: if two edges are identical except for letter, merge them and note both the letters | ||
if (not force) and len(self.Q) > maximum: | ||
return | ||
|
||
#suspicion: graphviz may be upset by certain sequences, avoid them in nodes | ||
label_to_number_dict = {False:0} #false is never a label but gets us started | ||
def label_to_numberlabel(label): | ||
max_number = max(label_to_number_dict[l] for l in label_to_number_dict) | ||
if not label in label_to_number_dict: | ||
label_to_number_dict[label] = max_number + 1 | ||
return str(label_to_number_dict[label]) | ||
|
||
def add_nodes(graph, nodes): #stolen from http://matthiaseisen.com/articles/graphviz/ | ||
for n in nodes: | ||
if isinstance(n, tuple): | ||
graph.node(n[0], **n[1]) | ||
else: | ||
graph.node(n) | ||
return graph | ||
|
||
def add_edges(graph, edges): #stolen from http://matthiaseisen.com/articles/graphviz/ | ||
for e in edges: | ||
if isinstance(e[0], tuple): | ||
graph.edge(*e[0], **e[1]) | ||
else: | ||
graph.edge(*e) | ||
return graph | ||
|
||
g = digraph() | ||
g = add_nodes(g, [(label_to_numberlabel(self.q0), {'color':'green' if self.q0 in self.F else 'black', | ||
'shape': 'hexagon', 'label':'start'})]) | ||
states = list(set(self.Q)-{self.q0}) | ||
g = add_nodes(g, [(label_to_numberlabel(state),{'color': 'green' if state in self.F else 'black', | ||
'label': str(i)}) | ||
for state,i in zip(states,range(1,len(states)+1))]) | ||
|
||
def group_edges(): | ||
def clean_line(line,group): | ||
line = line.split(separator) | ||
line = sorted(line) + ["END"] | ||
in_sequence= False | ||
last_a = "" | ||
clean = line[0] | ||
if line[0] in group: | ||
in_sequence = True | ||
first_a = line[0] | ||
last_a = line[0] | ||
for a in line[1:]: | ||
if in_sequence: | ||
if a in group and (ord(a)-ord(last_a))==1: #continue sequence | ||
last_a = a | ||
else: #break sequence | ||
#finish sequence that was | ||
if (ord(last_a)-ord(first_a))>1: | ||
clean += ("-" + last_a) | ||
elif not last_a == first_a: | ||
clean += (separator + last_a) | ||
#else: last_a==first_a -- nothing to add | ||
in_sequence = False | ||
#check if there is a new one | ||
if a in group: | ||
first_a = a | ||
last_a = a | ||
in_sequence = True | ||
if not a=="END": | ||
clean += (separator + a) | ||
else: | ||
if a in group: #start sequence | ||
first_a = a | ||
last_a = a | ||
in_sequence = True | ||
if not a=="END": | ||
clean += (separator+a) | ||
return clean | ||
|
||
|
||
edges_dict = {} | ||
for state in self.Q: | ||
for a in self.alphabet: | ||
edge_tuple = (label_to_numberlabel(state),label_to_numberlabel(self.delta[state][a])) | ||
# print(str(edge_tuple)+" "+a) | ||
if not edge_tuple in edges_dict: | ||
edges_dict[edge_tuple] = a | ||
else: | ||
edges_dict[edge_tuple] += separator+a | ||
# print(str(edge_tuple)+" = "+str(edges_dict[edge_tuple])) | ||
for et in edges_dict: | ||
edges_dict[et] = clean_line(edges_dict[et], string.ascii_lowercase) | ||
edges_dict[et] = clean_line(edges_dict[et], string.ascii_uppercase) | ||
edges_dict[et] = clean_line(edges_dict[et], "0123456789") | ||
edges_dict[et] = edges_dict[et].replace(separator,",") | ||
return edges_dict | ||
|
||
edges_dict = group_edges() | ||
g = add_edges(g,[(e,{'label':edges_dict[e]}) for e in edges_dict]) | ||
# print('\n'.join([str(((str(state),str(self.delta[state][a])),{'label':a})) for a in self.alphabet for state in | ||
# self.Q])) | ||
# g = add_edges(g,[((label_to_numberlabel(state),label_to_numberlabel(self.delta[state][a])),{'label':a}) | ||
# for a in self.alphabet for state in self.Q]) | ||
display(Image(filename=g.render(filename='img/automaton'))) | ||
|
||
def minimal_diverging_suffix(self,state1,state2): #gets series of letters showing the two states are different, | ||
# i.e., from which one state reaches accepting state and the other reaches rejecting state | ||
# assumes of course that the states are in the automaton and actually not equivalent | ||
res = None | ||
# just use BFS til you reach an accepting state | ||
# after experiments: attempting to use symmetric difference on copies with s1,s2 as the starting state, or even | ||
# just make and minimise copies of this automaton starting from s1 and s2 before starting the BFS, | ||
# is slower than this basic BFS, so don't | ||
seen_states = set() | ||
new_states = {("",(state1,state2))} | ||
while len(new_states) > 0: | ||
prefix,state_pair = new_states.pop() | ||
s1,s2 = state_pair | ||
if len([q for q in [s1,s2] if q in self.F])== 1: # intersection of self.F and [s1,s2] is exactly one state, | ||
# meaning s1 and s2 are classified differently | ||
res = prefix | ||
break | ||
seen_states.add(state_pair) | ||
for a in self.alphabet: | ||
next_state_pair = (self.delta[s1][a],self.delta[s2][a]) | ||
next_tuple = (prefix+a,next_state_pair) | ||
if not next_tuple in new_states and not next_state_pair in seen_states: | ||
new_states.add(next_tuple) | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from time import clock | ||
from ObservationTable import TableTimedOut | ||
from DFA import DFA | ||
from Teacher import Teacher | ||
from Lstar import run_lstar | ||
|
||
def extract(rnn,time_limit = 50,initial_split_depth = 10,starting_examples=None): | ||
print("provided counterexamples are:",starting_examples) | ||
guided_teacher = Teacher(rnn,num_dims_initial_split=initial_split_depth,starting_examples=starting_examples) | ||
start = clock() | ||
try: | ||
run_lstar(guided_teacher,time_limit) | ||
except KeyboardInterrupt: #you can press the stop button in the notebook to stop the extraction any time | ||
print("lstar extraction terminated by user") | ||
except TableTimedOut: | ||
print("observation table timed out during refinement") | ||
end = clock() | ||
extraction_time = end-start | ||
|
||
dfa = guided_teacher.dfas[-1] | ||
|
||
print("overall guided extraction time took: " + str(extraction_time)) | ||
|
||
print("generated counterexamples were: (format: (counterexample, counterexample generation time))") | ||
print('\n'.join([str(a) for a in guided_teacher.counterexamples_with_times])) | ||
return dfa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import dynet as dy | ||
from Helper_Functions import map_nested_dict | ||
|
||
class GRUCell: | ||
def __init__(self,input_dim,output_dim,pc): | ||
self.input_dim = input_dim | ||
self.output_dim = output_dim | ||
self.pc = pc | ||
self.gate_names = ["z","r","htilde"] | ||
self.gate_activations = [dy.logistic]*2+[dy.tanh] | ||
#todo: generalise to layers | ||
self.parameters = {"W":{"x":{},"h":{}},"b":{}} | ||
for gate in self.gate_names: | ||
self.parameters["W"]["x"][gate] = self.pc.add_parameters((self.output_dim,self.input_dim)) | ||
self.parameters["W"]["h"][gate] = self.pc.add_parameters((self.output_dim,self.output_dim)) #takes its own previous output | ||
self.parameters["b"][gate] = self.pc.add_parameters((self.output_dim)) | ||
self.parameters["h0"] = self.pc.add_parameters((self.output_dim)) | ||
self.parameters["h0"].clip_inplace(-1,1) | ||
|
||
self.store_expressions() | ||
|
||
|
||
def store_expressions(self): | ||
self.expressions = map_nested_dict(self.parameters,dy.parameter) | ||
self.parameters["h0"].clip_inplace(-1,1) | ||
self.initial_h = self.parameters["h0"].expr() | ||
|
||
|
||
|
||
def gate_vecs(self,ht1,xt): | ||
b = self.expressions["b"] | ||
W = self.expressions["W"] | ||
gate_vecs = {} | ||
for g,activation in zip(self.gate_names,self.gate_activations): | ||
hin = ht1 if not g=="htilde" else dy.cmult(gate_vecs["r"],ht1) | ||
gate_vecs[g] = activation(dy.affine_transform([b[g], | ||
W["x"][g],xt, | ||
W["h"][g],ht1])) | ||
return gate_vecs | ||
|
||
def gate_and_next_vecs(self,ht1,xt): | ||
v = self.gate_vecs(ht1,xt) | ||
h = dy.cmult(v["z"],ht1)+dy.cmult(1-v["z"],v["htilde"]) | ||
res = v | ||
res.update({"h":h}) | ||
return res | ||
|
||
from functools import reduce | ||
from operator import add | ||
class GRUNetworkState: | ||
def __init__(self,hs=None,full_vec=None,hidden_dim=None): | ||
if not None in [full_vec,hidden_dim]: | ||
hvec = full_vec | ||
self.hs = [dy.inputVector(hvec[i*hidden_dim:(i+1)*hidden_dim]) for i in range(int(len(hvec)/hidden_dim))] | ||
elif not None in [hs]: | ||
self.hs = hs #list of h expressions | ||
else: | ||
raise MissingInput() | ||
|
||
def output(self): | ||
return self.hs[-1] | ||
|
||
def as_vec(self): | ||
return reduce(add,[h.value() for h in self.hs]) | ||
# return np.concatenate([h.npvalue() for h in self.hs]).tolist() | ||
|
||
|
||
class GRUNetwork: | ||
def __init__(self,num_layers=None,input_dim=None,hidden_dim=None,pc=None,output_dim=None): | ||
if None in [num_layers,input_dim,hidden_dim,pc] or (num_layers <= 0): | ||
raise MissingInput() | ||
if None == output_dim: | ||
output_dim = hidden_dim | ||
|
||
self.num_layers = num_layers | ||
self.input_dim = input_dim | ||
self.hidden_dim = hidden_dim | ||
self.output_dim = output_dim | ||
self.pc = pc | ||
self.state_class = GRUNetworkState | ||
|
||
self.layers = [] | ||
if self.num_layers > 1: | ||
self.layers.append(GRUCell(self.input_dim,self.hidden_dim,self.pc)) | ||
for _ in range(num_layers-2): | ||
self.layers.append(GRUCell(self.hidden_dim,self.hidden_dim,self.pc)) | ||
self.layers.append(GRUCell(self.hidden_dim,self.output_dim,self.pc)) | ||
else: | ||
self.layers.append(GRUCell(self.input_dim,self.output_dim,self.pc)) | ||
|
||
def all_gate_and_next_vecs(self,state,input_vec): | ||
res = [] | ||
x = input_vec | ||
for layer,h in zip(self.layers,state.hs): | ||
res.append(layer.gate_and_next_vecs(h,x)) | ||
x = res[-1]["h"] #output of one layer is input to the next | ||
return res | ||
|
||
def next_state(self,state,input_vec): | ||
v = self.all_gate_and_next_vecs(state,input_vec) | ||
hs = [lvals["h"] for lvals in v] | ||
return GRUNetworkState(hs=hs) | ||
|
||
def store_expressions(self): | ||
for l in self.layers: | ||
l.store_expressions() | ||
self.initial_state = GRUNetworkState(hs=[l.initial_h for l in self.layers]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import itertools | ||
import random | ||
|
||
def mean(num_list): | ||
return sum(num_list)*1.0/len(num_list) | ||
|
||
def n_words_of_length(n,length,alphabet): | ||
if 50*n >= pow(len(alphabet),length): | ||
res = all_words_of_length(length, alphabet) | ||
random.shuffle(res) | ||
return res[:n] | ||
#else if 50*n < total words to be found, i.e. looking for 1/50th of the words or less | ||
res = set() | ||
while len(res)<n: | ||
word = "" | ||
for _ in range(length): | ||
word += random.choice(alphabet) | ||
res.add(word) | ||
return list(res) | ||
|
||
def all_words_of_length(length,alphabet): | ||
return [''.join(list(b)) for b in itertools.product(alphabet, repeat=length)] | ||
|
||
|
||
def compare(network,classifier,length,num_examples=1000,provided_samples=None): | ||
if not None == provided_samples: | ||
words = provided_samples | ||
else: | ||
words = n_words_of_length(num_examples,length,network.alphabet) | ||
disagreeing_words = [w for w in words if not (network.classify_word(w) == classifier.classify_word(w))] | ||
return 1-(len(disagreeing_words)/len(words)), disagreeing_words | ||
|
||
def map_nested_dict(d,mapper): | ||
if not isinstance(d,dict): | ||
return mapper(d) | ||
return {k:map_nested_dict(d[k],mapper) for k in d} | ||
|
||
class MissingInput(Exception): | ||
pass |
Oops, something went wrong.