Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/division property #291

Merged
merged 8 commits into from
Dec 6, 2024
173 changes: 106 additions & 67 deletions claasp/cipher_modules/division_trail_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from collections import Counter
from sage.rings.polynomial.pbori.pbori import BooleanPolynomialRing
from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids, _get_predecessors_subgraph
from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component
from gurobipy import Model, GRB
import os

verbosity = False

class MilpDivisionTrailModel():
"""

Expand All @@ -36,6 +39,7 @@ class MilpDivisionTrailModel():
This module can only be used if the user possesses a Gurobi license.

"""

def __init__(self, cipher):
self._cipher = cipher
self._variables = None
Expand All @@ -44,13 +48,20 @@ def __init__(self, cipher):
self._used_variables = []
self._variables_as_list = []
self._unused_variables = []
self._used_predecessors_sorted = None
self._output_id = None
self._output_bit_index_previous_comp = None
self._block_needed = None
self._input_id_link_needed = None

def get_all_variables_as_list(self):
for component_id in list(self._variables.keys())[:-1]:
for bit_position in self._variables[component_id].keys():
for value in self._variables[component_id][bit_position].keys():
if value != "current":
self._variables_as_list.append(self._variables[component_id][bit_position][value].VarName)
varname = self._variables[component_id][bit_position][value].VarName
if varname not in self._variables_as_list: # rot and intermediate has the same name than original
self._variables_as_list.append(varname)

def get_unused_variables(self):
self.get_all_variables_as_list()
Expand Down Expand Up @@ -81,9 +92,7 @@ def build_gurobi_model(self):
model = Model("basic_model", env=env)
# model = Model()
model.Params.LogToConsole = 0
model.Params.Threads = 16 # best found experimentaly on ascon_sbox_2rounds
model.setParam("PoolSolutions", 1234) # 200000000
model.setParam(GRB.Param.PoolSearchMode, 2)
# model.Params.Threads = 16
self._model = model

def get_anfs_from_sbox(self, component):
Expand Down Expand Up @@ -183,7 +192,6 @@ def add_sbox_constraints(self, component):
x = B.variable_names()
anfs = self.get_anfs_from_sbox(component)
anfs = [B(anfs[i]) for i in range(component.input_bit_size)]
# print(anfs)

copy_monomials_deg = self.create_gurobi_vars_sbox(component, input_vars_concat)

Expand Down Expand Up @@ -215,6 +223,52 @@ def add_sbox_constraints(self, component):
self._model.addConstr(output_vars[index] >= constr)
self._model.update()

def create_copies_for_linear_layer(self, binary_matrix, input_vars_concat):
copies = {}
for index, var in enumerate(input_vars_concat):
column = [row[index] for row in binary_matrix]
number_of_1s = list(column).count(1)
if number_of_1s > 1:
current = 1
else:
current = 0
copies[index] = {}
copies[index][0] = var
copies[index]["current"] = current
self.set_as_used_variables([var])
new_vars = self._model.addVars(list(range(number_of_1s)), vtype=GRB.BINARY,
name="copy_" + var.VarName)
self._model.update()
for i in range(number_of_1s):
self._model.addConstr(var >= new_vars[i])
self._model.addConstr(
sum(new_vars[i] for i in range(number_of_1s)) >= var)
self._model.update()
for i in range(1, number_of_1s + 1):
copies[index][i] = new_vars[i - 1]
return copies

def add_linear_layer_constraints(self, component):
output_vars = self.get_output_vars(component)
input_vars_concat = self.get_input_vars(component)

if component.type == "linear_layer":
binary_matrix = component.description
else:
binary_matrix = binary_matrix_of_linear_component(component)

copies = self.create_copies_for_linear_layer(binary_matrix, input_vars_concat)
for index_row, row in enumerate(binary_matrix):
constr = 0
for index_bit, bit in enumerate(row):
if bit:
current = copies[index_bit]["current"]
constr += copies[index_bit][current]
copies[index_bit]["current"] += 1
self.set_as_used_variables([copies[index_bit][current]])
self._model.addConstr(output_vars[index_row] == constr)
self._model.update()

def add_xor_constraints(self, component):
output_vars = self.get_output_vars(component)

Expand All @@ -230,20 +284,15 @@ def add_xor_constraints(self, component):
else:
input_vars_concat.append(self._variables[input_name][pos][current])
self._variables[input_name][pos]["current"] += 1
# print(input_vars_concat)

block_size = component.output_bit_size
nb_blocks = component.description[1]
if constant_flag != []:
nb_blocks -= 1
# print(self._occurences[component.id])
# print(list(self._occurences[component.id].keys()))
# print(len(list(self._occurences[component.id].keys())))
for index, bit_pos in enumerate(list(self._occurences[component.id].keys())):
constr = 0
for j in range(nb_blocks):
constr += input_vars_concat[index + block_size * j]
# print(input_vars_concat[index + block_size * j])
self.set_as_used_variables([input_vars_concat[index + block_size * j]])
if (constant_flag != []) and (constant_flag[index]):
self._model.addConstr(output_vars[index] >= constr)
Expand Down Expand Up @@ -372,12 +421,14 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed):
self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed)

used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys()))
self._used_predecessors_sorted = used_predecessors_sorted
for component_id in used_predecessors_sorted:
if component_id not in self._cipher.inputs:
component = self._cipher.get_component_from_id(component_id)
print(f"---------> {component.id}")
if component.type == "sbox":
self.add_sbox_constraints(component)
elif component.type in ["linear_layer", "mix_column"]:
self.add_linear_layer_constraints(component)
elif component.type in ["cipher_output", "constant", "intermediate_output"]:
continue
elif component.type == "word_operation":
Expand Down Expand Up @@ -414,13 +465,9 @@ def get_where_component_is_used(self, predecessors, input_id_link_needed, block_
component = self._cipher.get_component_from_id(input_id_link_needed)
occurences[input_id_link_needed] = [[i for i in range(component.output_bit_size)]]

# print("occurences")
# print(occurences)
occurences_final = {}
for component_id in occurences.keys():
occurences_final[component_id] = self.find_copy_indexes(occurences[component_id])
# print("occurences_final")
# print(occurences_final)

self._occurences = occurences_final
return occurences_final
Expand Down Expand Up @@ -462,13 +509,10 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee
occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
all_vars = {}
used_predecessors_sorted = self.order_predecessors(list(occurences.keys()))
print("used_predecessors_sorted")
print(used_predecessors_sorted)
for component_id in used_predecessors_sorted:
all_vars[component_id] = {}
# We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method.
# That's why we split the following loop: we first created the original vars, and then the copies vars when necessary.
# print(f"###### {component_id}")
if component_id[:3] == "rot":
component = self._cipher.get_component_from_id(component_id)
rotate_offset = component.description[1]
Expand Down Expand Up @@ -571,20 +615,15 @@ def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, c
block_needed = comp.input_bit_positions[index]
input_id_link_needed = chosen_cipher_output
output_bit_index_previous_comp = output_bit_index_ciphertext
print(output_id)
print(block_needed)
print(input_id_link_needed)
print(output_bit_index_previous_comp)
return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot
else:
output_id = self.get_cipher_output_component_id()
# output_id = "xor_1_69"
component = self._cipher.get_component_from_id(output_id)
pivot = 0
output_bit_index_previous_comp = output_bit_index_ciphertext
for index, block in enumerate(component.input_bit_positions):
if pivot <= output_bit_index_ciphertext < pivot + len(block):
output_bit_index_previous_comp = output_bit_index_ciphertext - pivot
output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot]
block_needed = block
input_id_link_needed = component.input_id_links[index]
break
Expand All @@ -609,31 +648,28 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component(
output_bit_index_ciphertext, chosen_cipher_output)

self._output_id = output_id
self._output_bit_index_previous_comp = output_bit_index_previous_comp
self._block_needed = block_needed
self._input_id_link_needed = input_id_link_needed

G = create_networkx_graph_from_input_ids(self._cipher)
predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed]))
for input_id in self._cipher.inputs + ['']:
if input_id in predecessors:
predecessors.remove(input_id)

# print("input_id_link_needed")
# print(input_id_link_needed)
# print("predecessors")
# print(predecessors)
self.add_constraints(predecessors, input_id_link_needed, block_needed)

var_from_block_needed = []
for i in block_needed:
var_from_block_needed.append(self._variables[input_id_link_needed][i][0])
# print("var_from_block_needed")
# print(var_from_block_needed)

output_vars = self._model.addVars(list(range(pivot, pivot + len(block_needed))), vtype=GRB.BINARY,
name=output_id)
self._variables[output_id] = output_vars
output_vars = list(output_vars.values())
self._model.update()
# print("output_vars")
# print(output_vars)

for i in range(len(block_needed)):
self._model.addConstr(output_vars[i] == var_from_block_needed[i])
Expand All @@ -654,10 +690,10 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex

self.set_unused_variables_to_zero()
self._model.update()
self._model.write("division_trail_model.lp")
end = time.time()
building_time = end - start
print(f"########## building_time : {building_time}")
if verbosity:
print(f"########## building_time : {building_time}")
self._model.update()

def get_solutions(self):
Expand All @@ -676,7 +712,6 @@ def get_solutions(self):
first_input_bit_positions = list(self._occurences[self._cipher.inputs[0]].keys())

solCount = self._model.SolCount
print('Number of solutions (might cancel each other) found: ' + str(solCount))
monomials = []
for sol in range(solCount):
self._model.setParam(GRB.Param.SolutionNumber, sol)
Expand All @@ -695,67 +730,80 @@ def get_solutions(self):
else:
if index < len(list(self._occurences[self._cipher.inputs[0]].keys())):
tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index])
if 1 not in values[:max_input_bit_pos]:
tmp += str(1)
else:
if nb_inputs_used == 1:
input1_prefix = self._cipher.inputs[0][0]
l = tmp.split(input1_prefix)[1:]
sorted_l = sorted(l, key=lambda x: (x == '', int(x) if x else 0))
l = [''] + sorted_l
tmp = input1_prefix.join(l)

if tmp in monomials:
monomials.remove(tmp)
else:
monomials.append(tmp)

end = time.time()
printing_time = end - start
print(f"########## printing_time : {printing_time}")
print(monomials)
print(f'Number of monomials found: {len(monomials)}')
if verbosity:
print('Number of solutions (might cancel each other) found: ' + str(solCount))
print(f"########## printing_time : {printing_time}")
print(f'Number of monomials found: {len(monomials)}')
return monomials

def optimize_model(self):
print(self._model)
start = time.time()
self._model.optimize()
end = time.time()
solving_time = end - start
print(f"########## solving_time : {solving_time}")
if verbosity:
print(self._model)
print(f"########## solving_time : {solving_time}")

def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)

# # Specific to Aradi analysis:
# for i in range(96):
# v = self._model.getVarByName(f"plaintext[{i}]")
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
# ########################
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)
self._model.write("division_trail_model.lp")

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_specific_anf(self, monomial, output_bit_index, fixed_degree=None,
chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)

for term in monomial:
var_term = self._model.getVarByName(f"{term[0]}[{term[1]}]")
self._model.addConstr(var_term == 1)
self._model.update()
self._model.write("division_trail_model.lp")

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None, chosen_cipher_output=None):
def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None,
chosen_cipher_output=None):
s = ""
for term in monomial:
s += term[0][0] + str(term[1])
for i in range(self._cipher.output_bit_size):
print(f"\nSearch of {s} in anf {i} :")
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree, chosen_cipher_output)
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree,
chosen_cipher_output)

def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_output=None):
fixed_degree = None
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam(GRB.Param.PoolSearchMode, 1)
self._model.setParam('Presolve', 2)
self._model.setParam('MIPFocus', 3)
# self._model.setParam('Cuts', 2)
self._model.setParam('NodefileStart', 2.0)
self._model.setParam("MIPFocus", 2)
self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution
self._model.setParam('Cuts', 2)

index_plaintext = self._cipher.inputs.index("plaintext")
plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext]
Expand All @@ -765,19 +813,10 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out
p.append(self._model.getVarByName(f"plaintext[{i}]"))
self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE)

## Specific to Aradi analysis:
# for i in range(128):
# v = self._model.getVarByName(f"plaintext[{i}]")
# if 0 <= i < 128: # free vars
# self._model.addConstr(v >= 0)
# else:
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
#######################

self._model.update()
self._model.write("division_trail_model.lp")
self.optimize_model()
# get degree

degree = self._model.getObjective().getValue()
return degree

Expand Down
Loading
Loading