You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.3 KiB
Python
59 lines
1.3 KiB
Python
import nn
|
|
import pickle
|
|
|
|
# Path of the pickle file holding the list of networks: saveAll() writes it
# and loadAll() reads it. It is a relative path, so it is resolved against the
# current working directory.
networks_filename = "networks.fiar"
|
|
|
|
def nextNNmove(network, game, player):
    """Let ``network`` choose a move for ``player`` and play it on ``game``.

    The board returned by ``game.getTransformedField()`` is one-hot encoded
    into two planes: the first marks every cell equal to 1, the second every
    cell equal to 2 (any other value is 0 in both planes). The concatenation
    of both planes is fed to ``network.calculateOutput``. Its outputs are
    ranked from highest to lowest, and ``game.setStone`` is tried for each
    index in that order until one call does not raise ``OverflowError``.

    Args:
        network: object exposing ``calculateOutput(inputs)``, which returns
            an indexable sequence of scores. There is presumably one score per
            column; the networks built by ``createAndSave`` have 7 outputs.
        game: object exposing ``getTransformedField()`` and
            ``setStone(index, player)``. The latter is expected to raise
            ``OverflowError`` for a move that cannot be played.
        player: value passed through unchanged to ``game.setStone``.

    Returns:
        The same ``game`` object. If every candidate raised ``OverflowError``,
        no stone is placed and ``game`` is returned unchanged; callers that
        care must detect that themselves.
    """
    cells = list(game.getTransformedField())
    # Two one-hot planes: cells holding a 1 first, then cells holding a 2.
    inputs = ([1 if spot == 1 else 0 for spot in cells]
              + [1 if spot == 2 else 0 for spot in cells])

    out = network.calculateOutput(inputs)

    # Rank output indices best-first. sorted() is stable (also with
    # reverse=True), so equal scores keep the lower index first. It leaves
    # `out` untouched and is correct for any score range; a sentinel-based
    # selection sort would repeat index 0 once all scores fall below the
    # sentinel.
    preferences = sorted(range(len(out)), key=lambda i: out[i], reverse=True)

    for column in preferences:
        try:
            game.setStone(column, player)
        except OverflowError:
            # Move rejected (presumably a full column); try the next best one.
            continue
        break
    return game
|
|
|
|
def createAndSave(num):
    """Build ``num`` fresh networks and persist the whole list via ``saveAll``.

    Each network is created as ``nn.NN(84, 168, 168, 84, 84, 42, 7)``. The
    arguments are presumably layer widths, from an 84-value input (two planes
    of 42 cells, as built by ``nextNNmove``) down to 7 outputs; confirm
    against ``nn.NN``. Any previously saved networks file is overwritten.
    """
    population = [nn.NN(84, 168, 168, 84, 84, 42, 7) for _ in range(num)]
    saveAll(population)
|
|
|
|
def saveAll(nets, filename=None):
    """Pickle ``nets`` to disk, replacing any previously saved file.

    Args:
        nets: the networks to store; any picklable object is accepted.
        filename: target path. Defaults to the module-level
            ``networks_filename`` when omitted.

    Raises:
        Whatever ``pickle.dumps`` raises for unpicklable content, and
        ``OSError`` if the file cannot be opened or written. A pickling
        failure leaves an existing file untouched.
    """
    if filename is None:
        filename = networks_filename
    # Serialize before opening: 'wb' truncates the target immediately, so
    # pickling straight into the file would wipe the previously saved
    # networks if serialization failed part-way.
    payload = pickle.dumps(nets)
    with open(filename, 'wb') as networks_file:
        networks_file.write(payload)
|
|
|
|
def loadAll(filename=None):
    """Load and return the object previously stored by ``saveAll``.

    Args:
        filename: path to read. Defaults to the module-level
            ``networks_filename`` when omitted.

    Returns:
        The unpickled object: the list of networks when the file was written
        by ``saveAll`` / ``createAndSave``.

    Raises:
        OSError (e.g. FileNotFoundError) if the file cannot be opened, and
        whatever ``pickle.load`` raises for corrupt or truncated data.

    Security: unpickling can execute arbitrary code. Only load files this
    program wrote itself; never point this at untrusted input.
    """
    if filename is None:
        filename = networks_filename
    with open(filename, 'rb') as networks_file:
        return pickle.load(networks_file)
|
|
|
|
def loadSingle(num):
    """Return the element at position ``num`` of the saved networks.

    The whole file is loaded through ``loadAll`` and then indexed. For a
    stored list, this means negative indices work as usual and an
    out-of-range ``num`` raises IndexError.
    """
    networks = loadAll()
    return networks[num]
|