Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/division property #291

Merged
merged 8 commits into from
Dec 6, 2024
150 changes: 104 additions & 46 deletions claasp/cipher_modules/division_trail_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections import Counter
from sage.rings.polynomial.pbori.pbori import BooleanPolynomialRing
from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids, _get_predecessors_subgraph
from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component
from gurobipy import Model, GRB
import os

Expand All @@ -36,6 +37,7 @@ class MilpDivisionTrailModel():
This module can only be used if the user possesses a Gurobi license.

"""

def __init__(self, cipher):
self._cipher = cipher
self._variables = None
Expand All @@ -44,13 +46,20 @@ def __init__(self, cipher):
self._used_variables = []
self._variables_as_list = []
self._unused_variables = []
self._used_predecessors_sorted = None
self._output_id = None
self._output_bit_index_previous_comp = None
self._block_needed = None
self._input_id_link_needed = None

def get_all_variables_as_list(self):
for component_id in list(self._variables.keys())[:-1]:
for bit_position in self._variables[component_id].keys():
for value in self._variables[component_id][bit_position].keys():
if value != "current":
self._variables_as_list.append(self._variables[component_id][bit_position][value].VarName)
varname = self._variables[component_id][bit_position][value].VarName
if varname not in self._variables_as_list: # rot and intermediate has the same name than original
self._variables_as_list.append(varname)

def get_unused_variables(self):
self.get_all_variables_as_list()
Expand Down Expand Up @@ -81,9 +90,7 @@ def build_gurobi_model(self):
model = Model("basic_model", env=env)
# model = Model()
model.Params.LogToConsole = 0
model.Params.Threads = 16 # best found experimentaly on ascon_sbox_2rounds
model.setParam("PoolSolutions", 1234) # 200000000
model.setParam(GRB.Param.PoolSearchMode, 2)
# model.Params.Threads = 16
self._model = model

def get_anfs_from_sbox(self, component):
Expand Down Expand Up @@ -215,12 +222,61 @@ def add_sbox_constraints(self, component):
self._model.addConstr(output_vars[index] >= constr)
self._model.update()

def create_copies_for_linear_layer(self, binary_matrix, input_vars_concat):
copies = {}
for index, var in enumerate(input_vars_concat):
column = [row[index] for row in binary_matrix]
number_of_1s = list(column).count(1)
if number_of_1s > 1:
current = 1
else:
current = 0
copies[index] = {}
copies[index][0] = var
copies[index]["current"] = current
self.set_as_used_variables([var])
new_vars = self._model.addVars(list(range(number_of_1s)), vtype=GRB.BINARY,
name="copy_" + var.VarName)
self._model.update()
for i in range(number_of_1s):
self._model.addConstr(var >= new_vars[i])
self._model.addConstr(
sum(new_vars[i] for i in range(number_of_1s)) >= var)
self._model.update()
for i in range(1, number_of_1s + 1):
copies[index][i] = new_vars[i - 1]
return copies

def add_linear_layer_constraints(self, component):
output_vars = self.get_output_vars(component)
input_vars_concat = self.get_input_vars(component)

if component.type == "linear_layer":
binary_matrix = component.description
else:
binary_matrix = binary_matrix_of_linear_component(component)

copies = self.create_copies_for_linear_layer(binary_matrix, input_vars_concat)
for index_row, row in enumerate(binary_matrix):
constr = 0
for index_bit, bit in enumerate(row):
if bit:
current = copies[index_bit]["current"]
constr += copies[index_bit][current]
copies[index_bit]["current"] += 1
self.set_as_used_variables([copies[index_bit][current]])
self._model.addConstr(output_vars[index_row] == constr)
self._model.update()

def add_xor_constraints(self, component):
output_vars = self.get_output_vars(component)
# print(output_vars)

input_vars_concat = []
constant_flag = []
for index, input_name in enumerate(component.input_id_links):
# print(input_name)
# print(self._variables[input_name])
for pos in component.input_bit_positions[index]:
current = self._variables[input_name][pos]["current"]
if input_name[:8] == "constant":
Expand Down Expand Up @@ -372,12 +428,15 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed):
self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed)

used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys()))
self._used_predecessors_sorted = used_predecessors_sorted
for component_id in used_predecessors_sorted:
if component_id not in self._cipher.inputs:
component = self._cipher.get_component_from_id(component_id)
print(f"---------> {component.id}")
# print(f"---------> {component.id}")
if component.type == "sbox":
self.add_sbox_constraints(component)
elif component.type in ["linear_layer", "mix_column"]:
self.add_linear_layer_constraints(component)
elif component.type in ["cipher_output", "constant", "intermediate_output"]:
continue
elif component.type == "word_operation":
Expand Down Expand Up @@ -462,8 +521,8 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee
occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
all_vars = {}
used_predecessors_sorted = self.order_predecessors(list(occurences.keys()))
print("used_predecessors_sorted")
print(used_predecessors_sorted)
# print("used_predecessors_sorted")
# print(used_predecessors_sorted)
for component_id in used_predecessors_sorted:
all_vars[component_id] = {}
# We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method.
Expand Down Expand Up @@ -571,20 +630,15 @@ def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, c
block_needed = comp.input_bit_positions[index]
input_id_link_needed = chosen_cipher_output
output_bit_index_previous_comp = output_bit_index_ciphertext
print(output_id)
print(block_needed)
print(input_id_link_needed)
print(output_bit_index_previous_comp)
return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot
else:
output_id = self.get_cipher_output_component_id()
# output_id = "xor_1_69"
component = self._cipher.get_component_from_id(output_id)
pivot = 0
output_bit_index_previous_comp = output_bit_index_ciphertext
for index, block in enumerate(component.input_bit_positions):
if pivot <= output_bit_index_ciphertext < pivot + len(block):
output_bit_index_previous_comp = output_bit_index_ciphertext - pivot
output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot]
block_needed = block
input_id_link_needed = component.input_id_links[index]
break
Expand All @@ -608,17 +662,21 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
start = time.time()
output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component(
output_bit_index_ciphertext, chosen_cipher_output)
# print(output_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comments to be removed

# print(block_needed)
# print(input_id_link_needed)
# print(output_bit_index_previous_comp)
self._output_id = output_id
self._output_bit_index_previous_comp = output_bit_index_previous_comp
self._block_needed = block_needed
self._input_id_link_needed = input_id_link_needed

G = create_networkx_graph_from_input_ids(self._cipher)
predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed]))
for input_id in self._cipher.inputs + ['']:
if input_id in predecessors:
predecessors.remove(input_id)

# print("input_id_link_needed")
# print(input_id_link_needed)
# print("predecessors")
# print(predecessors)
self.add_constraints(predecessors, input_id_link_needed, block_needed)

var_from_block_needed = []
Expand Down Expand Up @@ -654,7 +712,7 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex

self.set_unused_variables_to_zero()
self._model.update()
self._model.write("division_trail_model.lp")
# self._model.write("division_trail_model.lp")
end = time.time()
building_time = end - start
print(f"########## building_time : {building_time}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build time should be one of the outputs, as done in the trail search.
The output can be a json

Expand Down Expand Up @@ -695,6 +753,16 @@ def get_solutions(self):
else:
if index < len(list(self._occurences[self._cipher.inputs[0]].keys())):
tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index])
if 1 not in values[:max_input_bit_pos]:
tmp += str(1)
else:
if nb_inputs_used == 1:
input1_prefix = self._cipher.inputs[0][0]
l = tmp.split(input1_prefix)[1:]
sorted_l = sorted(l, key=lambda x: (x == '', int(x) if x else 0))
l = [''] + sorted_l
tmp = input1_prefix.join(l)

if tmp in monomials:
monomials.remove(tmp)
else:
Expand All @@ -703,8 +771,8 @@ def get_solutions(self):
end = time.time()
printing_time = end - start
print(f"########## printing_time : {printing_time}")
print(monomials)
print(f'Number of monomials found: {len(monomials)}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to print

return monomials

def optimize_model(self):
print(self._model)
Expand All @@ -716,46 +784,45 @@ def optimize_model(self):

def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)

# # Specific to Aradi analysis:
# for i in range(96):
# v = self._model.getVarByName(f"plaintext[{i}]")
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
# ########################
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_specific_anf(self, monomial, output_bit_index, fixed_degree=None,
chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)

for term in monomial:
var_term = self._model.getVarByName(f"{term[0]}[{term[1]}]")
self._model.addConstr(var_term == 1)
self._model.update()
self._model.write("division_trail_model.lp")

self.optimize_model()
self.get_solutions()
return self.get_solutions()

def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None, chosen_cipher_output=None):
def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None,
chosen_cipher_output=None):
s = ""
for term in monomial:
s += term[0][0] + str(term[1])
for i in range(self._cipher.output_bit_size):
print(f"\nSearch of {s} in anf {i} :")
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree, chosen_cipher_output)
self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree,
chosen_cipher_output)

def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_output=None):
fixed_degree = None
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam(GRB.Param.PoolSearchMode, 1)
self._model.setParam('Presolve', 2)
self._model.setParam('MIPFocus', 3)
# self._model.setParam('Cuts', 2)
self._model.setParam('NodefileStart', 2.0)
self._model.setParam("MIPFocus", 2)
self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution
self._model.setParam('Cuts', 2)

index_plaintext = self._cipher.inputs.index("plaintext")
plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext]
Expand All @@ -765,19 +832,10 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out
p.append(self._model.getVarByName(f"plaintext[{i}]"))
self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE)

## Specific to Aradi analysis:
# for i in range(128):
# v = self._model.getVarByName(f"plaintext[{i}]")
# if 0 <= i < 128: # free vars
# self._model.addConstr(v >= 0)
# else:
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
#######################

self._model.update()
self._model.write("division_trail_model.lp")
self.optimize_model()
# get degree

degree = self._model.getObjective().getValue()
return degree

Expand Down
54 changes: 41 additions & 13 deletions tests/unit/cipher_modules/division_trail_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,46 @@
from claasp.ciphers.permutations.gaston_sbox_permutation import GastonSboxPermutation
from claasp.ciphers.block_ciphers.aradi_block_cipher import AradiBlockCipher
from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher
from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher
from claasp.cipher_modules.division_trail_search import *

def test_get_where_component_is_used():
cipher = SimonBlockCipher(number_of_rounds=1)
"""

Given a number of rounds of a chosen cipher and a chosen output bit, this module produces a model that can either:
- obtain the ANF of this chosen output bit,
- find the degree of this ANF,
- or check the presence or absence of a specified monomial.

This module can only be used if the user possesses a Gurobi license.

"""

def test_find_anf_of_specific_output_bit():
# Return the monomials of the anf of the chosen output bit
cipher = SimonBlockCipher(number_of_rounds=2)
milp = MilpDivisionTrailModel(cipher)
predecessors = ['intermediate_output_0_0', 'rot_0_1', 'rot_0_2', 'rot_0_3', 'and_0_4', 'xor_0_5', 'xor_0_6', 'intermediate_output_0_7', 'cipher_output_0_8']
input_id_link_needed = 'xor_0_6'
block_needed = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
occurences = milp.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
assert list(occurences.keys()) == ['plaintext', 'key', 'rot_0_1', 'rot_0_2', 'rot_0_3', 'and_0_4', 'xor_0_5', 'xor_0_6']
monomials = milp.find_anf_of_specific_output_bit(0)
assert monomials == ['p18','k32','p0','p3p24','p0p3p9','p2p9p24','p0p2p9','p10p17','p2p9p10','p10k49','p3k56','p17p24','p2p9k56','p0p9p17','k50','p24k49','p0p9k49','p4','k49k56','p17k56']

def test_get_monomial_occurences():
cipher = GastonSboxPermutation(number_of_rounds=1)
# Return the monomials of degree 2 of the anf of the chosen output bit
cipher = SimonBlockCipher(number_of_rounds=2)
milp = MilpDivisionTrailModel(cipher)
component = cipher.get_component_from_id('sbox_0_30')
anfs = milp.get_anfs_from_sbox(component)
assert len(anfs) == 5
monomials = milp.find_anf_of_specific_output_bit(0, fixed_degree=2)
assert monomials ==['p17p24', 'p0p9k49', 'p3p24', 'p2p9k56', 'p10p17']

def test_find_degree_of_specific_output_bit():
# Return the degree of the anf of the chosen output bit of the ciphertext
cipher = AradiBlockCipher(number_of_rounds=1)
milp = MilpDivisionTrailModel(cipher)
degree = milp.find_degree_of_specific_output_bit(0)
assert degree == 3

# Return the degree of the anf of the chosen output bit of the component xor_0_12
cipher = AradiBlockCipher(number_of_rounds=1)
milp = MilpDivisionTrailModel(cipher)
degree = milp.find_degree_of_specific_output_bit(0, chosen_cipher_output='xor_0_12')
assert degree == 3

cipher = SpeckBlockCipher(number_of_rounds=1)
milp = MilpDivisionTrailModel(cipher)
degree = milp.find_degree_of_specific_output_bit(15)
Expand All @@ -34,4 +50,16 @@ def test_find_degree_of_specific_output_bit():
cipher = GastonSboxPermutation(number_of_rounds=1)
milp = MilpDivisionTrailModel(cipher)
degree = milp.find_degree_of_specific_output_bit(0)
assert degree == 2
assert degree == 2

cipher = MidoriBlockCipher(number_of_rounds=2)
milp = MilpDivisionTrailModel(cipher)
degree = milp.find_degree_of_specific_output_bit(0)
assert degree == 8

def test_check_presence_of_particular_monomial_in_specific_anf():
# Return the all monomials that contains p230 of the anf of the chosen output bit
cipher = GastonSboxPermutation(number_of_rounds=1)
milp = MilpDivisionTrailModel(cipher)
monomials = milp.check_presence_of_particular_monomial_in_specific_anf([("plaintext", 230)], 0)
assert monomials == ['p181p230','p15p230','p33p230','p54p230','p55p230','p82p230','p100p230','p114p230','p115p230','p128p230','p140p230','p141p230','p146p230','p223p230','p205p230','p209p230','p210p230','p230p267','p230p313','p230p314','p230p315']
Loading