Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
alfredgalichon committed Jan 17, 2024
1 parent 55a5227 commit d4602c2
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
Binary file modified mec/__pycache__/nf.cpython-311.pyc
Binary file not shown.
63 changes: 59 additions & 4 deletions mec/nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,18 @@ def iterate(self, draw = False, verbose=0):
##################################################################
##################################################################

def adjust_label_pos(pos, edge_labels, shift_factor=0.1):
pos_labels = {}
for edge, label in edge_labels.items():
midpoint = ((pos[edge[0]][0] + pos[edge[1]][0]) / 2, (pos[edge[0]][1] + pos[edge[1]][1]) / 2)
pos_labels[edge] = (midpoint[0] + shift_factor, midpoint[1])
print(pos_labels)
return pos_labels


class Bipartite_EQF_problem:
def __init__(self, galois_xy, n_x,m_y, seed=777, verbose=0):

def __init__(self, galois_xy, n_x,m_y, label_galois_xy=None, seed=777, verbose=0):
self.n_x,self.m_y = n_x,m_y
self.nbx,self.nby = len(n_x),len(m_y)
self.nba = self.nbx*self.nby

Expand All @@ -403,10 +410,58 @@ def __init__(self, galois_xy, n_x,m_y, seed=777, verbose=0):
self.pos.update((node, (1, index)) for index, node in enumerate(bottom_nodes)) # Set one side for one set
self.pos.update((node, (2, index)) for index, node in enumerate(top_nodes))
self.galois_xy = galois_xy
self.label_galois_xy = label_galois_xy
#self.create_pricing_tree()

def draw(self,drawPrice = False, mu_a = None, plot_galois = False, entering_a = None, departing_a = None ,gain_a = None, figsize=(50, 30)):

edge_labels = {e: '' for e in self.digraph.edges()}

if plot_galois:
for e in self.digraph.edges():
edge_labels[e] += self.label_galois_xy[e]


if mu_a is not None:
for (i,e) in enumerate(self.arcsList):
if i in self.basis():
edge_labels[e] += '\nμ='+f"{mu_a[i]:.0f}"

if gain_a is not None:
for (i,e) in enumerate(self.arcsList):
if gain_a[i]!=self.c_a[i]:
edge_labels[e] += '\ng='+f"{gain_a[i]:.0f}"

# Adjust positions of labels
#label_pos = adjust_label_pos(self.pos, edge_labels, shift_factor=0.1)

nx.draw_networkx_edge_labels(self.digraph,self.pos,
edge_labels=edge_labels,
font_color='red')
# label_pos=label_pos)


q_z = np.concatenate([self.n_x,self.m_y])
labels = {z: f"{ ('n'*self.nbx+ 'm'*self.nby)[i]}={q_z[i]:.0f}"+'\n'+z+'\n' for i,z in enumerate(self.digraph.nodes()) }
label_pos = {z: (position[0], position[1] ) for z, position in self.pos.items()}
if drawPrice:
p_z = np.concatenate([np.zeros(1),p_z])
labels = {z: labels[z]+ f"\np={p_z[i]:.0f}" for i,z in enumerate(self.digraph.nodes())}

def draw(self):
nx.draw(self.digraph, self.pos, with_labels=True)
nx.draw_networkx_labels(self.digraph, label_pos, labels,font_size=10,verticalalignment = 'center')


if entering_a is not None:
nx.draw_networkx_edges(self.digraph, self.pos, edgelist=[self.arcsList[entering_a]],
edge_color='green', connectionstyle='arc3,rad=0.3')

if departing_a is not None:
nx.draw_networkx_edges(self.digraph, self.pos, edgelist=[self.arcsList[departing_a]],
style='dotted',
edge_color='white')

nx.draw(self.digraph, self.pos, with_labels=False)
plt.figure(figsize=figsize)
plt.show()

def create_pricing_tree(self, verbose = False):
Expand Down

0 comments on commit d4602c2

Please sign in to comment.