added experiments with xor gate
dhavala committed Oct 21, 2024
1 parent 26f12fc commit d0ce036
Showing 1 changed file with 249 additions and 28 deletions.
277 changes: 249 additions & 28 deletions boolgrad/notebooks/xor.ipynb
@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -22,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -370,7 +379,7 @@
"from bnn import ProdTron, SumTron, ProdLayer, SumLayer\n",
"\n",
"# can we learn an xor gate?\n",
"def optim_xor(x,w,y, model):\n",
"def optim_xor(x,w,y):\n",
" h = [ xi^wi for xi, wi in zip(x,w)]\n",
" yh = np.prod(h) # \n",
" yhd = [yh.data]\n",
@@ -506,7 +515,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -611,42 +620,254 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"***\n",
"\n",
"input: -1 -1\n",
"y: -1\n",
"model: -1 -1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: -1 1\n",
"y: -1\n",
"model: -1 -1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: 1 -1\n",
"y: 1\n",
"model: -1 -1\n",
"model is WRONG.\n",
"# 1 and got UPDATED +++\n",
"model after: [-1, 1]\n",
"pred before: -1 \n",
"pred after: 1\n",
"\n",
"***\n",
"\n",
"input: 1 1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"# mistakes 0\n",
"\n",
"***\n",
"\n",
"input: -1 -1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: -1 1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: 1 -1\n",
"y: 1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: 1 \n",
"pred after: 1\n",
"\n",
"***\n",
"\n",
"input: 1 1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"# mistakes 0\n",
"\n",
"***\n",
"\n",
"input: -1 -1\n",
"y: -1\n",
"model: 1 -1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: -1 1\n",
"y: -1\n",
"model: 1 -1\n",
"model is WRONG.\n",
"# 1 and got UPDATED +++\n",
"model after: [-1, -1]\n",
"pred before: 1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: 1 -1\n",
"y: 1\n",
"model: -1 -1\n",
"model is WRONG.\n",
"# 1 and got UPDATED +++\n",
"model after: [-1, 1]\n",
"pred before: -1 \n",
"pred after: 1\n",
"\n",
"***\n",
"\n",
"input: 1 1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"# mistakes 0\n",
"\n",
"***\n",
"\n",
"input: -1 -1\n",
"y: -1\n",
"model: -1 -1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: -1 1\n",
"y: -1\n",
"model: -1 -1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"\n",
"***\n",
"\n",
"input: 1 -1\n",
"y: 1\n",
"model: -1 -1\n",
"model is WRONG.\n",
"# 1 and got UPDATED +++\n",
"model after: [-1, 1]\n",
"pred before: -1 \n",
"pred after: 1\n",
"\n",
"***\n",
"\n",
"input: 1 1\n",
"y: -1\n",
"model: -1 1\n",
"model is CORRECT\n",
"and NOT updated +++\n",
"pred before: -1 \n",
"pred after: -1\n",
"# mistakes 0\n",
"Truth Table can be learnt if all entire Truth Table is seen by the model\n"
]
}
],
"source": [
"# we should verify this for all different initializations of the weights\n",
"\n",
"# we want to realize x1*x2' term. The truth table for this term is\n",
"T = [[-1,-1,-1],[-1,1,-1],[1,-1,1],[1,1,-1]]\n",
"W = [[-1,-1],[-1,1],[1,-1],[-1,-1]]\n",
"\n",
"mistakes = 0\n",
"w = [Bool(-1),Bool(-1)]\n",
"\n",
"flag = False\n",
"# for some rows in the truth table, this model is wrong\n",
"# by looping through other data, will we be able to eventually update the model?\n",
"for wi in W:\n",
" w = [Bool(wi[0]),Bool(wi[-1])]\n",
" for element in T:\n",
" print('\\n***\\n')\n",
" x = [Bool(element[0]),Bool(element[1])]\n",
" y = Bool(element[2]) \n",
" print('input: ',x[0].data, x[1].data)\n",
" print('y: ',y.data)\n",
" print('model: ',w[0].data, w[1].data)\n",
" w,yhd,flips,mistake = optim_xor(x,w,y)\n",
" mistakes += mistake\n",
" print('# mistakes',mistakes) \n",
" \n",
" # see if it is correct for all inputs\n",
" for element in T:\n",
" x = [Bool(element[0]),Bool(element[1])]\n",
" yh = np.prod([ xi^wi for xi, wi in zip(x,w)])\n",
" if yh.data != element[2]:\n",
" flag = True\n",
" print('-- Failed --')\n",
"if flag:\n",
" print('Truth Table can not learnt')\n",
"else:\n",
" print('Truth Table can be learnt if all entire Truth Table is seen by the model')"
]
},
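For readers without the repo at hand, here is a minimal, dependency-free sketch of the same check. It assumes the notebook's ±1 encoding (-1 = False, +1 = True) and that `Bool`'s `^` and `*` act as XOR and AND; `xor_pm`, `and_pm`, and `predict` are our own stand-ins, not `bnn` APIs.

```python
from itertools import product

def xor_pm(a, b):
    # XOR in the +/-1 encoding: +1 (True) iff a != b
    return -a * b

def and_pm(h):
    # AND in the +/-1 encoding: +1 (True) iff every element is +1
    return 1 if all(v == 1 for v in h) else -1

def predict(x, w):
    # the notebook's model: AND over the per-bit XOR(x_i, w_i)
    return and_pm([xor_pm(xi, wi) for xi, wi in zip(x, w)])

# truth table for x1 AND NOT x2, as in the cell above
T = [(-1, -1, -1), (-1, 1, -1), (1, -1, 1), (1, 1, -1)]

# brute-force every weight vector instead of sweeping a hand-picked W
for w in product((-1, 1), repeat=2):
    ok = all(predict((x1, x2), w) == y for x1, x2, y in T)
    print(w, 'realizes the table' if ok else 'fails on some row')
```

Under these assumptions only `w = (-1, 1)` realizes the table, which matches the final model the sweep above keeps converging to.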
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The implication is, entire Truth Table must be shown to the optimizer for updating the weights.\n",
"However, when only single \"example\" is given, and model is wrong for that example - that error signal is not enough. It depends on the Truth Table state for that particular input to flip the model weights.\n",
"\n",
"for element in T:\n",
" print('\\n***\\n')\n",
" x = [Bool(element[0]),Bool(element[1])]\n",
" y = Bool(element[2]) \n",
" print('input: ',x[0].data, x[1].data)\n",
" print('y: ',y.data)\n",
" print('model: ',w[0].data, w[1].data)\n",
" w,yhd,flips,mistake = optim_xor(x,w,y)\n",
" mistakes += mistake\n",
"# final model\n",
"print(w)\n",
"When such gates are present in millions in neural networks, the opportunity to be in a bad state grow exponentially. \n",
"Therefore, either we have to update the gradient definition to be decideable or we have to use the gradients as \"noisy signal\" to drive the errors to be smaller, but can get stuck occasionally.\n",
"\n"
]
},
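To make the ambiguity concrete, here is a small self-contained sketch (our own illustration with plain ±1 ints, not the `bnn` update rule): starting from `w = (1, -1)`, the single wrong row `(-1, 1) -> -1` can be repaired by flipping either weight, and the error signal from that one row cannot distinguish the two flips. Neither flip reaches the unique solution `(-1, 1)`, so the model still errs on row `(1, -1)` until that row is shown.

```python
def xor_pm(a, b):
    # XOR in the +/-1 encoding: +1 (True) iff a != b
    return -a * b

def predict(x, w):
    h = [xor_pm(xi, wi) for xi, wi in zip(x, w)]
    return 1 if all(v == 1 for v in h) else -1  # AND over the hidden bits

T = [(-1, -1, -1), (-1, 1, -1), (1, -1, 1), (1, 1, -1)]  # x1 AND NOT x2
w = (1, -1)             # third initialization from the sweep above
x, y = (-1, 1), -1      # the single row this model gets wrong

for i in range(2):      # try each single-weight flip
    w_new = tuple(-wj if j == i else wj for j, wj in enumerate(w))
    row_fixed = predict(x, w_new) == y
    table_ok = all(predict((a, b), w_new) == t for a, b, t in T)
    print(f'flip w{i+1}: {w} -> {w_new}; row fixed: {row_fixed}; whole table: {table_ok}')
```

Both flips print `row fixed: True; whole table: False`, mirroring the trace above where the init `(1, -1)` needs two updates before it reaches `(-1, 1)`.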
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Exercise\n",
"\n",
"# see if it is correct for all inputs\n",
"for element in T:\n",
" print('\\n***\\n')\n",
" x = [Bool(element[0]),Bool(element[1])]\n",
" yh = np.prod([ xi^wi for xi, wi in zip(x,w)])\n",
" if yh.data == element[2]:\n",
" print('correct')\n",
" else:\n",
" print('wrong')"
"Consider an 2-ary MLP that is modeling an XOR Gate. For a certain combination of input and model weights, gradients are not decidable, arising from the decidability of the primitive Gates. \n",
"\n",
"Gradients, defined according to BOLD, are undecidable when any of its inputs are 0 (to be ignored). This impacts the feedback loop to control the error.\n",
"\n",
"How often can this happen? Does this depend on the topology of the network?\n",
"\n",
"Specifically, any term in the 2-ary XOR Gate will take the form $AND(XOR(x_1,w_1),XOR(x_2,w_2))$. Note the NOT gate is actually an XOR gate. $XOR(x,T)=\\neg x = NOT(x)$, $XOR(x,F) = x$. Therefore, with this 2-ary MLP, we can model any 2-ary Product terms of the SoP, which forms the backbone to model much general Truth Tables.\n",
"\n",
"When $x_i=w_i \\, \\& \\, x_1=x_2$, the model is $AND(\\neg x_1 \\neg x_2) = F$ for $x_i = T$. Local gradient of $AND(.,.)$ is 0 in this case. Therefore, error will not propagate backwards.\n",
"\n",
"Now consider an K-ary MLP formed of 2-ary MLP (of the form above) that is both deep and wide, randomly initialized. For the sake of discussion consider the architecture of constant width of $H$, with depth $D$ (not including the input and output layers). This network has $P = 2(KH + DH^2 + H)$. As a result, the network will have $2^P$ possible initial states. \n",
"\n",
"Of the $2^K$ possible inputs, $2^P$ random configurations, how many rows of the Truth Table can not be learnt with a single backprop iteration?"
]
}
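As a starting point for the exercise, a sketch of our own (not repo code): in the ±1 encoding, the local gradient of AND with respect to one input is 0 whenever the other input is F, so backprop through the 2-ary gate is blocked entirely exactly when both hidden bits $XOR(x_i,w_i)$ are F, i.e. when $x_i = w_i$ for both $i$. Below, an exhaustive count over the 2-ary gate, plus the parameter count $P$ from the text with illustrative sizes of our choosing.

```python
from itertools import product

# count (input, weight) configurations where no error reaches either weight
blocked = 0
for x1, x2, w1, w2 in product((-1, 1), repeat=4):
    h1, h2 = -x1 * w1, -x2 * w2      # hidden bits XOR(x_i, w_i) in +/-1 form
    if h1 == -1 and h2 == -1:        # AND output insensitive to both inputs
        blocked += 1
print(f'{blocked}/16 configurations block backprop entirely')  # prints 4/16

# parameter count from the text, P = 2(KH + DH^2 + H), for example sizes
K, H, D = 4, 8, 3                    # illustrative sizes, our choice
P = 2 * (K * H + D * H * H + H)
print(f'P = {P} parameters -> 2**{P} possible initial states')
```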
],