import nn
import pickle

# File that holds the pickled list of networks written by saveAll / read by loadAll.
networks_filename = "networks.fiar"


def nextNNmove(network, game, player):
    """Let *network* choose a move for *player* and play it on *game*.

    The board from ``game.getTransformedField()`` is one-hot encoded into two
    planes (cells equal to 1, then cells equal to 2), concatenated and fed to
    ``network.calculateOutput``. Output indices are tried in order of
    decreasing output value (ties: lower index first) via
    ``game.setStone(index, player)``; an index whose placement raises
    ``OverflowError`` is skipped and the next-best index is tried.

    Args:
        network: object exposing ``calculateOutput(inputs)`` that returns an
            indexable, sized sequence of comparable numbers.
        game: object exposing ``getTransformedField()`` and
            ``setStone(index, player)``.
        player: player identifier passed through to ``game.setStone``.

    Returns:
        The same ``game`` object. If every candidate index raises
        ``OverflowError``, no stone is placed and the game is returned as is.
    """
    flatfield = game.getTransformedField()
    # One-hot planes: value 1 -> first plane, value 2 -> second plane,
    # any other value is treated as empty in both.
    player1 = [1 if spot == 1 else 0 for spot in flatfield]
    player2 = [1 if spot == 2 else 0 for spot in flatfield]
    out = network.calculateOutput(player1 + player2)

    # Rank output indices, highest value first. sorted() stays stable with
    # reverse=True, so equal outputs keep the lower index first (same
    # tie-break as a first-max selection). Unlike the previous -20-sentinel
    # selection loop, this ranks arbitrarily low outputs correctly and does
    # not overwrite the network's returned `out`.
    preferences = sorted(range(len(out)), key=lambda i: out[i], reverse=True)

    for val in preferences:
        try:
            game.setStone(val, player)
        except OverflowError:
            # NOTE(review): presumably raised for a full column -- confirm in
            # the game class. Fall through to the next-best choice.
            continue
        break
    return game


def createAndSave(num):
    """Create *num* fresh ``nn.NN`` networks and persist them via saveAll.

    Each network is built as ``nn.NN(84, 168, 168, 84, 84, 42, 7)``.
    NOTE(review): the arguments are presumably layer sizes (84 inputs matching
    the two concatenated planes from nextNNmove, 7 outputs) -- confirm
    against ``nn.NN``.
    """
    saveAll([nn.NN(84, 168, 168, 84, 84, 42, 7) for _ in range(num)])


def saveAll(nets):
    """Pickle the list *nets* to ``networks_filename``, overwriting it."""
    with open(networks_filename, 'wb') as networks_file:
        pickle.dump(nets, networks_file)


def loadAll():
    """Unpickle and return the object stored in ``networks_filename``.

    Security: ``pickle.load`` can execute arbitrary code; only load files
    this program wrote itself, never untrusted ones.
    """
    with open(networks_filename, 'rb') as networks_file:
        return pickle.load(networks_file)


def loadSingle(num):
    """Return the network at index *num* of the saved list.

    The whole file is unpickled on every call; callers needing several
    networks should call loadAll once instead.
    """
    return loadAll()[num]