201 lines
6.0 KiB
Python
201 lines
6.0 KiB
Python
|
import hashlib
|
||
|
import math
|
||
|
import numpy as np
|
||
|
import random
|
||
|
|
||
|
def count_one_bits(n):
    """Return how many '1' digits appear in ``bin(n)``.

    For a negative ``n`` this is the popcount of ``abs(n)``, because ``bin``
    renders negatives as ``-0b...``.
    """
    return sum(digit == "1" for digit in bin(n))
|
||
|
|
||
|
def xor_n(n):
    """Return the bit parity of ``n``: 1 if it has an odd number of set bits, else 0."""
    ones = count_one_bits(n)
    # ``ones`` is never negative, so masking the low bit equals ``% 2``.
    return ones & 1
|
||
|
|
||
|
def sha_n(n):
    """Return a deterministic pseudo-random bit derived from ``n``.

    The bit is the least-significant bit of the first byte of the SHA-256
    digest of ``str(n)`` encoded as UTF-8.
    """
    first_byte = hashlib.sha256(str(n).encode("utf-8")).digest()[0]
    return first_byte & 0b1
|
||
|
|
||
|
def xor_by_index(knowns, index, reverse=False):
    """Return a copy of ``knowns`` with labels flipped according to one index bit.

    Each entry is a ``(g, j, value)`` triple. Its ``value`` is replaced by
    ``value ^ 1`` when bit ``index`` of ``j`` is set (``reverse`` false) or
    clear (``reverse`` true). This is the same rule ``evaluate`` applies when
    it replays an ``(index, reverse)`` flip, so recorded flips reproduce the
    training-time relabelling at prediction time.

    Args:
        knowns: iterable of ``(g, j, value)`` triples; not modified.
        index: bit position of ``j`` to test.
        reverse: if truthy, flip entries whose bit is clear instead of set.

    Returns:
        A new list of ``(g, j, value)`` triples.
    """
    mask = 1 << index
    flip_when_set = not reverse
    # Flip exactly when the bit's state matches the requested side.
    # (Previously ``reverse=True`` flipped every entry regardless of the bit.)
    return [
        (g, j, value ^ 1) if ((j & mask) != 0) == flip_when_set else (g, j, value)
        for (g, j, value) in knowns
    ]
|
||
|
|
||
|
def remove_bit(i, n):
    """Delete bit position ``n`` from ``i``, shifting every higher bit down by one.

    Bits below ``n`` keep their positions; bit ``n`` itself is discarded.
    """
    low_bits = i & ((1 << n) - 1)
    high_bits = (i >> (n + 1)) << n
    return low_bits | high_bits
|
||
|
|
||
|
def split_at(knowns, N, i):
    """Partition ``knowns`` on bit ``i`` of each entry's index, dropping that bit.

    Args:
        knowns: iterable of ``(g, j, value)`` triples.
        N: accepted for signature symmetry; not used.
        i: bit position of ``j`` to split on.

    Returns:
        ``(left, right)``: entries whose ``j`` has bit ``i`` clear, and those
        with it set. In both lists ``j`` is replaced by ``remove_bit(j, i)``,
        and the input order is preserved.
    """
    mask = 1 << i
    halves = ([], [])
    for (g, j, value) in knowns:
        side = 1 if j & mask else 0
        halves[side].append((g, remove_bit(j, i), value))
    return halves
|
||
|
|
||
|
def factor_at(knowns, N, i, identity_value=1):
    """Split ``knowns`` into two overlapping subsets around bit ``i``.

    Args:
        knowns: iterable of ``(g, j, value)`` triples.
        N: accepted for signature symmetry; not used.
        i: bit position of ``j`` to factor on.
        identity_value: entries with this ``value`` go into both subsets.

    Returns:
        ``(left, right)``: ``left`` holds entries whose value equals
        ``identity_value`` or whose ``j`` has bit ``i`` clear; ``right`` holds
        entries whose value equals ``identity_value`` or whose ``j`` has bit
        ``i`` set. Input order is preserved in both.
    """
    mask = 1 << i
    left, right = [], []
    for (g, j, value) in knowns:
        record = (g, j, value)
        is_identity = value == identity_value
        bit_set = (j & mask) != 0
        if is_identity or not bit_set:
            left.append(record)
        if is_identity or bit_set:
            right.append(record)
    return (left, right)
|
||
|
|
||
|
def key_for_knowns(knowns):
    """Return the tuple of each entry's first field ``g``, in order.

    The result is hashable. Every entry must unpack into exactly three items.
    """
    ids = []
    for g, _j, _value in knowns:
        ids.append(g)
    return tuple(ids)
|
||
|
|
||
|
# Small integer constants; no function visible in this chunk references them.
# NOTE(review): 1 is not a prime number. Confirm whether including it is
# intentional before relying on this list.
primes = [
    1, 2, 3, 5, 7,
    11, 13, 17, 19, 23,
]
|
||
|
|
||
|
def compute_split_knowns_r(knowns, N):
    """Score how consistently the labels in ``knowns`` agree across bit splits.

    Starting from the whole set, each subset with more than two entries is
    split on every one of its ``n`` remaining index bits (via ``split_at``).
    Splits that leave one side empty are skipped. Terminal subsets contribute
    evidence:

    * a single entry counts as agreement;
    * a pair counts as agreement if both labels match, otherwise as
      disagreement. Its weight is divided by ``2 ** popcount(a ^ b)`` of the
      two (reduced) indices.

    Contributions are weighted by ``depth ** 64`` with ``depth = N - n``, so
    subsets reached after more splits carry far more weight.

    Args:
        knowns: list of ``(g, j, value)`` triples.
        N: number of index bits.

    Returns:
        ``numerator / denominator`` in ``[0.0, 1.0]``; 1.0 when every weighted
        terminal subset agrees. If there is no weighted evidence at all (empty
        input, only depth-0 terminal subsets whose weight is 0, or no usable
        split), returns 1.0 if all labels are equal and 0.0 otherwise. The
        original code raised ``ZeroDivisionError`` in that case.
    """
    stack = [(knowns, N)]
    numerator = 0.0
    denominator = 0.0

    while stack:
        (s, n) = stack.pop()
        depth = N - n
        weight = depth ** 64

        if len(s) == 1:
            numerator += weight
            denominator += weight
            continue

        if len(s) == 2:
            (_, a, left_value) = s[0]
            (_, b, right_value) = s[1]
            # Pairs whose indices differ in more bits are weaker evidence.
            weight /= 2 ** count_one_bits(a ^ b)
            if left_value == right_value:
                numerator += weight
            denominator += weight
            continue

        for i in range(n):
            (left, right) = split_at(s, n, i)
            if not left or not right:
                continue
            stack.append((left, n - 1))
            stack.append((right, n - 1))

    if denominator == 0:
        # No weighted evidence: fall back to a plain "do all labels agree?"
        # check instead of dividing by zero.
        return 1.0 if len({value for (_, _, value) in knowns}) <= 1 else 0.0
    return numerator / denominator
|
||
|
|
||
|
def invert(knowns):
    """Return a new list of ``(i, value)`` pairs with each ``value`` replaced by ``1 - value``.

    Each entry must unpack into exactly two items; the input is not modified.
    NOTE(review): the other helpers in this file use ``(g, j, value)``
    triples, which this function cannot unpack. Confirm its intended input.
    """
    return [(i, 1 - value) for (i, value) in knowns]
|
||
|
|
||
|
def reduce(knowns, N):
    """Greedily relabel ``knowns`` with single-bit flips to maximise coherence.

    Each round tries every ``(bit, reverse)`` flip from ``xor_by_index`` and
    keeps the one that most improves ``compute_split_knowns_r``. Rounds stop
    when the score reaches 1.0 or no flip strictly improves it. Progress is
    printed to stdout.

    Returns:
        ``(knowns, coherence, flips)``: the relabelled knowns, their final
        score, and the applied ``(bit, reverse)`` flips in order.
    """
    applied = []
    score = compute_split_knowns_r(knowns, N)
    print(score)
    print(knowns)
    print()

    while score < 1.0:
        choice = None
        for bit in range(N):
            for rev in (False, True):
                candidate = xor_by_index(knowns, bit, rev)
                candidate_score = compute_split_knowns_r(candidate, N)
                print(bit, rev, candidate_score)
                # Strict improvement over the best seen so far (including the
                # current labelling) is required to select a flip.
                if candidate_score > score:
                    score = candidate_score
                    choice = (bit, rev)

        if choice is None:
            break

        knowns = xor_by_index(knowns, *choice)
        applied.append(choice)
        print()
        print(choice[0], choice[1], score)
        print(knowns)
        print()

    return (knowns, score, applied)
|
||
|
|
||
|
def solve(knowns, N):
    """Fit a flip-based model to ``knowns``.

    Runs ``reduce``. If it reaches coherence 1.0, the value of the first
    relabelled known becomes the model's ``inverted`` bit.

    Returns:
        ``(inverted, flips, None)``, the model tuple consumed by ``evaluate``.

    Raises:
        RuntimeError: if ``reduce`` cannot reach coherence 1.0. Factoring into
            child models (the ``child`` slot of the model tuple) is not
            implemented. The previous sketch of it was unreachable and called an
            undefined ``compute_coherence``.
    """
    (knowns, coherence, flips) = reduce(knowns, N)
    if coherence == 1.0:
        (_, _, inverted) = knowns[0]
        return (inverted, flips, None)
    raise RuntimeError(
        f"solve: reduce() stopped at coherence {coherence}; "
        "factoring into child models is not implemented"
    )
|
||
|
|
||
|
def evaluate(model, n, value=0):
    """Predict the label of input ``n`` from a model tuple.

    ``model`` is ``(inverted, flips, child)``. Starting from ``value``:

    * each ``(bit, reverse)`` in ``flips`` toggles the running value when bit
      ``bit`` of ``n`` is set (``reverse`` falsy) or clear (``reverse`` truthy);
    * if ``child`` is not None it is ``(identity, left_model, right_model)``.
      Both sub-models are evaluated from start value ``1 - identity``. The
      value toggles when both results are truthy, and toggles again when
      ``identity == 0``;
    * finally the value toggles once more if ``inverted`` is truthy.
    """
    inverted, flips, child = model

    for bit, reverse in flips:
        bit_is_set = (n & (1 << bit)) > 0
        if bit_is_set != bool(reverse):
            value = 1 - value

    if child is not None:
        identity, left_model, right_model = child
        start = 1 - identity
        left = evaluate(left_model, n, start)
        right = evaluate(right_model, n, start)
        if left and right:
            value = 1 - value
        # NOTE(review): source indentation was lost; this toggle is
        # reconstructed as a sibling of the AND toggle above. Confirm.
        if identity == 0:
            value = 1 - value

    if inverted:
        value = 1 - value
    return value
|
||
|
|
||
|
def main(*, random_train=False):
    """Fit a model for ``xor_n`` on a small training set and print test accuracy.

    By default the training set is a fixed, hand-picked list of indices. With
    ``random_train=True``, ``train_size`` distinct indices are drawn uniformly
    from ``range(S)`` instead. This may be very slow, since the coherence
    search explores every bit split.

    Accuracy is measured on ``test_size`` uniformly random inputs in
    ``[0, S - 1]`` and printed as ``correct/test_size``.
    """
    N = 8
    S = 2 ** N
    train_size = 128
    test_size = 100
    f = xor_n

    if random_train:
        # Distinct indices in [0, S), each labelled with f at that index.
        # Fixes the old commented-out sampler, which could draw S, labelled
        # with f(i) instead of f(k), and built 2-tuples instead of triples.
        train_indices = random.sample(range(S), train_size)
    else:
        # Other hand-picked sets tried during experimentation:
        #   0, 1, 2, 3, 4, 5, 6, 7
        #   0, 3, 4, 5, 7
        #   3, 5, 6, 10, 12, 14
        #   1, 3, 7, 10, 14, 15
        #   0, 3, 5, 6, 10, 11, 12
        #   0, 3, 5, 6, 10, 11, 12, 24, 30, 52, 63, 255, 243, 127
        #   128, 131, 248, 0, 7, 13, 17, 19
        #   23, 38, 46, 89, 108, 110, 114, 119, 137, 168, 177, 201, 206, 232, 247, 255
        train_indices = [0, 3, 5, 6, 10, 11, 12, 24, 30]

    knowns = [(i, i, f(i)) for i in train_indices]

    model = solve(knowns, N)

    correct = 0
    for _ in range(test_size):
        k = random.randint(0, S - 1)
        if f(k) == evaluate(model, k):
            correct += 1
    print(str(correct) + "/" + str(test_size))


if __name__ == "__main__":
    main()
|