Skip to content

Commit

Permalink
FEATURE/Add: Adding semi-deterministic truncated model
Browse files Browse the repository at this point in the history
Adding cost objective on SAT model
  • Loading branch information
juaninf committed Jan 24, 2025
1 parent 8ed60a0 commit 02b2a2b
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,61 @@
from claasp.cipher_modules.models.sat.sat_model import SatModel
from claasp.cipher_modules.models.utils import set_component_solution
from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL,
INTERMEDIATE_OUTPUT, INPUT_PLAINTEXT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION)

INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION)


def group_triples(var_names):
"""
Given a list of variable names (strings) of the form
hw_{p|q|r}_modadd_X_Y_Z_W
group them by (X,Y,Z,W) and return a dict:
grouped[(X, Y, Z, W)] = (pVarName, qVarName, rVarName)
where pVarName, qVarName, rVarName are the corresponding names.
We assume each group has exactly three variables: one for p, q, and r.
"""
grouped = {}
for name in var_names:
# Example name: "hw_p_modadd_0_1_0_0"
# Split by '_'
parts = name.split('_')
# parts[1] = 'p' or 'q' or 'r'
# the block parts[3], parts[4], parts[5] correspond to X, Y, Z
# possibly parts[6] is W (if present)

bit_id = parts[1] # 'p', 'q', or 'r'

# we expect something like: parts = ["hw", "p", "modadd", X, Y, Z, W]
# e.g. "hw_p_modadd_0_1_3_0"
X = parts[3]
Y = parts[4]
Z = parts[5]
W = parts[6] if len(parts) > 6 else "0" # sometimes there's an extra index

key = (X, Y, Z, W)

if key not in grouped:
grouped[key] = {'p': None, 'q': None, 'r': None}
grouped[key][bit_id] = name

# Convert the dictionary of bit_id->name into a tuple (pName, qName, rName)
# for easier use later:
triples_dict = {}
for k, bit_map in grouped.items():
# each bit_map is e.g. {'p': 'hw_p_modadd_...', 'q': 'hw_q_modadd_...', 'r': 'hw_r_modadd_...'}
p_name = bit_map['p']
q_name = bit_map['q']
r_name = bit_map['r']
triples_dict[k] = (p_name, q_name, r_name)

return triples_dict

class SatSemiDeterministicTruncatedXorDifferentialModel(SatModel):
def __init__(self, cipher, counter='sequential', compact=False):
super().__init__(cipher, counter, compact)

def build_semi_deterministic_truncated_xor_differential_trail_model(self, number_of_unknown_variables=None,
fixed_variables=[]):
def build_semi_deterministic_truncated_xor_differential_trail_model(
self, number_of_unknown_variables=None, weight=None, fixed_variables=[]
):
"""
Build the model for the search of deterministic truncated XOR DIFFERENTIAL trails.
Expand Down Expand Up @@ -72,11 +118,17 @@ def build_semi_deterministic_truncated_xor_differential_trail_model(self, number
self._variables_list.extend(variables)
self._model_constraints.extend(constraints)

if number_of_unknown_variables is not None:
# if number_of_unknown_variables is not None:
# variables, constraints = self.weight_constraints(number_of_unknown_variables)
# self._variables_list.extend(variables)
# self._model_constraints.extend(constraints)

if weight is not None:
variables, constraints = self.weight_constraints(number_of_unknown_variables)
self._variables_list.extend(variables)
self._model_constraints.extend(constraints)


@staticmethod
def fix_variables_value_constraints(fixed_variables=[]):
"""
Expand Down Expand Up @@ -147,8 +199,12 @@ def fix_variables_value_constraints(fixed_variables=[]):

return constraints

def find_one_semi_deterministic_truncated_xor_differential_trail(self, fixed_values=[],
solver_name=solvers.SOLVER_DEFAULT):
def find_one_semi_deterministic_truncated_xor_differential_trail(
self,
fixed_values=[],
solver_name=solvers.SOLVER_DEFAULT,
unknown_probability_weight_configuration=None
):
"""
Returns one deterministic truncated XOR differential trail.
Expand Down Expand Up @@ -194,6 +250,9 @@ def find_one_semi_deterministic_truncated_xor_differential_trail(self, fixed_val
"""
start_building_time = time.time()
self.build_semi_deterministic_truncated_xor_differential_trail_model(fixed_variables=fixed_values)
if unknown_probability_weight_configuration is not None:
self.weight_constraints(unknown_probability_weight_configuration)

end_building_time = time.time()
solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name)
solution['building_time_seconds'] = end_building_time - start_building_time
Expand Down Expand Up @@ -249,7 +308,7 @@ def find_lowest_varied_patterns_semi_deterministic_truncated_xor_differential_tr

return solution

def weight_constraints(self, number_of_unknown_variables):
def weight_constraints(self, configuration):
"""
Return lists of variables and constraints that fix the number of unknown
variables of the input and the output of the trail to a specific value.
Expand All @@ -274,19 +333,166 @@ def weight_constraints(self, number_of_unknown_variables):
'-cipher_output_2_12_30_0 -dummy_hw_0_61_3',
'-cipher_output_2_12_31_0 -dummy_hw_0_62_3'])
"""
cipher_output_id = self._cipher.get_all_components_ids()[-1]
set_to_be_minimized = [f"{INPUT_PLAINTEXT}_{i}_0"
for i in range(self._cipher.inputs_bit_size[self._cipher.inputs.index(INPUT_PLAINTEXT)])]
set_to_be_minimized.extend([bit_id for bit_id in self._variables_list
if bit_id.startswith(cipher_output_id) and bit_id.endswith("_0")])

return self._counter(set_to_be_minimized, number_of_unknown_variables)
max_number_of_sequences_window_size_0 = configuration['max_number_of_sequences_window_size_0']
max_number_of_sequences_window_size_1 = configuration['max_number_of_sequences_window_size_1']
max_number_of_sequences_window_size_2 = configuration['max_number_of_sequences_window_size_2']

hw_variables = [var_id for var_id in self._variables_list if var_id.startswith('hw_')]

def x_iff_abc_cnf(a: str, b: str, c: str, x: str) -> list:
"""
Generate CNF clauses for x <-> a, b, c.
Args:
a, b, c: Strings representing boolean variables (can include negations with '-').
x: String representing the boolean variable x (can include negations with '-').
Returns:
List of CNF clauses in f-string format where OR is represented by space and negations by '-'.
"""

def negate(var):
"""Return the negated form of a variable."""
return var[1:] if var.startswith("-") else f"-{var}"

clauses = [
f"{negate(x)} {a}", # -x OR a
f"{negate(x)} {b}", # -x OR b
f"{negate(x)} {c}", # -x OR c
f"{negate(a)} {negate(b)} {negate(c)} {x}" # -a OR -b OR -c OR x
]
return clauses

triples_dict = group_triples(hw_variables)
window_1_vars = []
window_2_vars = []
for tuple_key, tuple_value in triples_dict.items():
window_1_var = "hw_window_1" + "_".join(tuple_key)
window_1_vars.append(window_1_var)
constraints = x_iff_abc_cnf(
tuple_value[0], "-" + tuple_value[1], tuple_value[2], window_1_var
)
# import ipdb; ipdb.set_trace()
self._variables_list.extend([window_1_var])
self._model_constraints.extend(constraints)

window_2_var = "hw_window_2" + "_".join(tuple_key)
window_2_vars.append(window_2_var)
constraints = x_iff_abc_cnf(
tuple_value[0], "-" + tuple_value[1], "-" + tuple_value[2], window_2_var
)
self._variables_list.extend([window_2_var])
self._model_constraints.extend(constraints)
cardinality_variables_window_1, cardinality_constraints_window_1 = self._counter(
window_1_vars, max_number_of_sequences_window_size_1
)
self._model_constraints.extend(cardinality_constraints_window_1)
self._variables_list.extend(cardinality_variables_window_1)
cardinality_variables_window_2, cardinality_constraints_window_2 = self._counter(
window_2_vars, max_number_of_sequences_window_size_2
)
self._model_constraints.extend(cardinality_constraints_window_2)
self._variables_list.extend(cardinality_variables_window_2)

def _calculate_component_weight(self, component, variable2value):
def map_dicts(dict1, dict2):
"""
Map values from dict1 to keys defined in dict2.
Args:
dict1 (dict): A dictionary containing keys and their corresponding values.
dict2 (dict): A dictionary where keys are tuples and values are tuples of keys from dict1.
Returns:
dict: A dictionary where each key from dict2 maps to a sub-dictionary of values from dict1.
"""
result = {}

for key, triplets in dict2.items():
# Create a sub-dictionary for each key in dict2
try:
sub_dict = {triplet: dict1.get(triplet, None) for triplet in triplets}
except:
import ipdb;
ipdb.set_trace()
result[key] = sub_dict

return result

def get_probability_expressions(input_dict):
"""
Process the input dictionary and calculate the number of times
each (P, Q, R) combination occurs.
Args:
input_dict (dict): A dictionary where each key is a tuple,
and values are dictionaries with keys P, Q, R.
Returns:
dict: A dictionary where keys are decimal equivalents of binary
representations (P, Q, R), and values are the counts of occurrences.
"""
# Initialize a dictionary to store counts for each combination
counts = {0: 0, 4: 0, 9: 0, 19: 0, 41: 0, 100: 0}

# Iterate over the input dictionary
for key, values in input_dict.items():
# Extract the values of P, Q, R
if list(values.keys()) == [None] and list(values.values()) == [None]:
continue

p = next((v for k, v in values.items() if "p_modadd" in k), None)
q = next((v for k, v in values.items() if "q_modadd" in k), None)
r = next((v for k, v in values.items() if "r_modadd" in k), None)

# Ensure all three are present
if p is not None and q is not None and r is not None:
# Convert P, Q, R to binary and calculate decimal equivalent
binary = f"{p}{q}{r}"
decimal = int(binary, 2)

# Map binary decimal equivalent to the desired output format
decimal_map = {
0: 0, # 000
1: 4, # 001
2: 9, # 010
3: 19, # 011
4: 41, # 100
5: 100, # 101
}

# Increment the count for the mapped value
if decimal in decimal_map:
counts[decimal_map[decimal]] += 1

return counts

weight = 0
if ('MODSUB' in component.description or 'MODADD' in component.description or 'AND' in component.description
or 'OR' in component.description or SBOX in component.type):

hw_variables = [var_id for var_id in self._variables_list if var_id.startswith('hw_')]
hw_variables = [var_id for var_id in hw_variables if component.id in var_id]
triples_dict = group_triples(hw_variables)

result_triples = map_dicts(variable2value, triples_dict)
probability_counts = get_probability_expressions(result_triples)

for key, value in probability_counts.items():
weight += value * (key / 100)

return weight


def _parse_solver_output(self, variable2value):
components_solutions = self._get_cipher_inputs_components_solutions_double_ids(variable2value)
total_weight = 0
for component in self._cipher.get_all_components():
value = self._get_component_value_double_ids(component, variable2value)
component_solution = set_component_solution(value)
weight = self._calculate_component_weight(component, variable2value)
total_weight += weight
component_solution = set_component_solution(value, weight)
components_solutions[f'{component.id}'] = component_solution

return components_solutions, None
return components_solutions, total_weight
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
SatSemiDeterministicTruncatedXorDifferentialModel
from claasp.cipher_modules.models.utils import set_fixed_variables
from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher
from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation


def test_find_one_semi_deterministic_truncated_xor_differential_trail():
Expand Down Expand Up @@ -32,3 +33,93 @@ def test_find_one_semi_deterministic_truncated_xor_differential_trail():

assert trail['components_values']['cipher_output_2_12']['value'] == '???????????????0????????????????'


def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration():
speck = SpeckBlockCipher(number_of_rounds=3)
sat = SatSemiDeterministicTruncatedXorDifferentialModel(speck)

plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32),
bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0])

intermediate_output_0_6 = set_fixed_variables(
component_id='intermediate_output_0_6', constraint_type='equal', bit_positions=range(32),
bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

intermediate_output_1_12 = set_fixed_variables(
component_id='intermediate_output_1_12', constraint_type='equal', bit_positions=range(32),
bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1])

cipher_output_2_12 = set_fixed_variables(
component_id='cipher_output_2_12', constraint_type='equal', bit_positions=range(32),
bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1])

key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64),
bit_values=(0,) * 64)
trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail(
fixed_values=[plaintext, intermediate_output_0_6, intermediate_output_1_12, cipher_output_2_12, key],
unknown_probability_weight_configuration={
"max_number_of_sequences_window_size_0": 20,
"max_number_of_sequences_window_size_1": 20,
"max_number_of_sequences_window_size_2": 20
}
)

print(trail)

assert trail['components_values']['cipher_output_2_12']['value'] == '????0??????????0???????????????1'


def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration_unsat():
speck = SpeckBlockCipher(number_of_rounds=3)
sat = SatSemiDeterministicTruncatedXorDifferentialModel(speck)

plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32),
bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0])

intermediate_output_0_6 = set_fixed_variables(
component_id='intermediate_output_0_6', constraint_type='equal', bit_positions=range(32),
bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

intermediate_output_1_12 = set_fixed_variables(
component_id='intermediate_output_1_12', constraint_type='equal', bit_positions=range(32),
bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1])

cipher_output_2_12 = set_fixed_variables(
component_id='cipher_output_2_12', constraint_type='equal', bit_positions=range(32),
bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1])

key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64),
bit_values=(0,) * 64)
trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail(
fixed_values=[plaintext, intermediate_output_0_6, intermediate_output_1_12, cipher_output_2_12, key],
unknown_probability_weight_configuration={
"max_number_of_sequences_window_size_0": 20,
"max_number_of_sequences_window_size_1": 1,
"max_number_of_sequences_window_size_2": 20
}
)

assert trail['status'] == 'UNSATISFIABLE'


def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration_chacha():
chacha = ChachaPermutation(number_of_rounds=6)
sat = SatSemiDeterministicTruncatedXorDifferentialModel(chacha)
state_size = 512
initial_state = [0] * state_size
initial_state[389] = 1
plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(state_size),
bit_values=initial_state)

trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail(
fixed_values=[plaintext],
unknown_probability_weight_configuration={
"max_number_of_sequences_window_size_0": 20,
"max_number_of_sequences_window_size_1": 1,
"max_number_of_sequences_window_size_2": 20
}
)

assert trail['status'] == 'UNSATISFIABLE'

0 comments on commit 02b2a2b

Please sign in to comment.