Skip to content

Commit

Permalink
Update nf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-jacquet committed Jan 11, 2024
1 parent fe250fa commit baa9832
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions mec/nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create_connected_bipartite(nbx, nby, zero_node = 0, seed=777):


class Network_problem:
def __init__(self,nodesList,arcsList,c_a,q_z , active_basis = None, zero_node = 0, pos=None,seed=777, verbose=0):
def __init__(self, nodesList, arcsList, c_a, q_z, active_basis=None, zero_node=0, pos=None, seed=777, verbose=0):
self.zero_node = zero_node
self.nbz = len(nodesList)
self.nba = len(arcsList)
Expand All @@ -75,7 +75,7 @@ def __init__(self,nodesList,arcsList,c_a,q_z , active_basis = None, zero_node =
self.nodesDict = {node:node_ind for (node_ind,node) in enumerate(self.nodesList)}
self.arcsDict = {arc:arc_ind for (arc_ind,arc) in enumerate(self.arcsList)}
if verbose>1:
print('Number of nodes='+str(self.nbz)+'; number of arcs='+str(self.nba)+'.')
print('Number of nodes = '+str(self.nbz)+'; number of arcs = '+str(self.nba)+'.')

data = np.concatenate([-np.ones(self.nba),np.ones(self.nba)])
arcsIndices = list(range(self.nba))
Expand All @@ -90,23 +90,21 @@ def __init__(self,nodesList,arcsList,c_a,q_z , active_basis = None, zero_node =
self.nabla0_a_z = self.nabla_a_z[:,znotzero]

active_basis = two_phase(self.nabla0_a_z.T,self.q0_z)
assert len(active_basis)== (self.nbz - 1)
assert len(active_basis) == (self.nbz - 1)

self.digraph = nx.DiGraph()
self.digraph.add_edges_from(arcsList)
if pos is None:
pos = nx.spring_layout(self.digraph, seed=seed)
self.pos = pos


arcsnames = [str(x)+str(y) for (x,y) in self.arcsList]
arcsNames = [str(x)+str(y) for (x,y) in self.arcsList]

self.tableau = Tableau(A_i_j = np.asarray(np.linalg.solve(self.B(active_basis),self.N(active_basis))),
d_i = self.q0_z,
c_j = self.gain_a(active_basis)[self.nonbasis(active_basis)],
slack_var_names_i = [arcsnames[i] for i in active_basis] ,
decision_var_names_j = [arcsnames[i] for i in self.nonbasis(active_basis)]
)
d_i = self.q0_z,
c_j = self.gain_a(active_basis)[self.nonbasis(active_basis)],
slack_var_names_i = [arcsNames[i] for i in active_basis],
decision_var_names_j = [arcsNames[i] for i in self.nonbasis(active_basis)])
the_arcs_indices = list(active_basis) + list(set(range(self.nba)) - set(active_basis))
self.a_k = the_arcs_indices
self.k_a = {the_arcs_indices[k]:k for k in range(self.nba)}
Expand Down Expand Up @@ -237,9 +235,9 @@ def iterate(self, draw = False, verbose=0):
return(2)

class EQF_problem(Network_problem):
def __init__(self, nodesList,arcsList,galois_xy,q_z , active_basis = None, zero_node = 0, pos=None,seed=777, verbose=0):
def __init__(self, nodesList, arcsList, galois_xy, q_z, active_basis=None, zero_node=0, pos=None, seed=777, verbose=0):
c_a = np.zeros(len(arcsList))
Network_problem.__init__(self, nodesList,arcsList,c_a,q_z , active_basis, zero_node, pos,seed, verbose)
Network_problem.__init__(self, nodesList, arcsList, c_a, q_z, active_basis, zero_node, pos, seed, verbose)
self.galois_xy = galois_xy
self.create_pricing_tree()

Expand Down

0 comments on commit baa9832

Please sign in to comment.