package games.scenarios;

import ec.util.MersenneTwisterFast;
import games.Board;
import games.BoardGame;
import games.GameMove;
import games.Player;

import java.util.List;

import cecj.app.othello.OthelloBoard;


/**
 * Self-play scenario that trains a {@link Player}'s board weights with temporal-difference
 * learning while the same player object chooses moves for both sides of a {@link BoardGame}.
 *
 * <p>On each turn, with probability {@code prob} a uniformly random legal move is played and no
 * learning takes place. Otherwise the move with the highest {@code game.evalMove(player, move)}
 * is played, and the weights are adjusted toward the squashed value of the resulting position.
 * If that move ended the game, they are adjusted toward the sign of the final outcome instead.
 */
public class SelfPlayTDLScenario implements GameScenario {

	/** Probability of playing a random (exploratory) move instead of the greedy one. */
	private final double prob;
	/** Player whose weights are used for move selection and updated by learning. */
	private final Player player;
	/** Step size applied to every TD weight update. */
	private final double learningRate;
	/** Source of randomness for exploration decisions and random move choice. */
	private final MersenneTwisterFast random;

	/**
	 * Creates a self-play TD learning scenario.
	 *
	 * @param random       random number generator used for exploration
	 * @param player       player whose weights are used for play and trained
	 * @param prob         probability of an exploratory random move on each turn. It is passed to
	 *                     {@code MersenneTwisterFast.nextBoolean(double)}, so presumably it must
	 *                     lie in [0, 1].
	 * @param learningRate TD step size
	 */
	public SelfPlayTDLScenario(MersenneTwisterFast random, Player player, double prob,
			double learningRate) {
		this.prob = prob;
		this.player = player;
		this.random = random;
		this.learningRate = learningRate;
	}

	/**
	 * Plays {@code game} to completion, applying a learning update after every greedy move.
	 *
	 * <p>A side with no legal moves passes: the turn simply switches to the opponent.
	 *
	 * @param game the game to play; it is mutated in place
	 * @return {@code game.getOutcome()} once {@code game.endOfGame()} reports true
	 */
	public int play(BoardGame game) {
		while (!game.endOfGame()) {
			List<? extends GameMove> moves = game.findMoves();
			if (!moves.isEmpty()) {
				if (random.nextBoolean(prob)) {
					// Exploratory move: played without a learning update.
					game.makeMove(moves.get(random.nextInt(moves.size())));
				} else {
					GameMove bestMove = selectGreedyMove(game, moves);
					// Snapshot the position before the move; its features drive the update.
					Board previousBoard = game.getBoard().clone();
					game.makeMove(bestMove);
					updateEvaluationFunction(previousBoard, game);
				}
			}
			game.switchPlayer();
		}

		return game.getOutcome();
	}

	/**
	 * Returns the first move with the strictly highest {@code game.evalMove(player, move)}.
	 *
	 * <p>If no evaluation exceeds negative infinity (for example, all are NaN or -Infinity), the
	 * first move is returned. A non-empty list therefore never yields {@code null}.
	 *
	 * @param game  game used to evaluate candidate moves
	 * @param moves non-empty list of legal moves
	 * @return the selected move, never {@code null}
	 */
	private GameMove selectGreedyMove(BoardGame game, List<? extends GameMove> moves) {
		GameMove bestMove = moves.get(0);
		double bestEval = Double.NEGATIVE_INFINITY;
		for (GameMove move : moves) {
			double eval = game.evalMove(player, move);
			if (eval > bestEval) {
				bestEval = eval;
				bestMove = move;
			}
		}
		return bestMove;
	}

	/**
	 * Applies one TD update to {@code player}'s weights for the transition
	 * {@code previousBoard -> game.getBoard()}.
	 *
	 * <p>The value of a position s is V(s) = tanh(s.evaluate(player)). Each weight w(row, col)
	 * is moved by
	 * {@code learningRate * (target - V(previous)) * (1 - V(previous)^2) * previousBoard.getValueAt(row, col)}.
	 * The target is the sign of {@code game.getOutcome()} if the game has ended, and
	 * V(game.getBoard()) otherwise.
	 *
	 * <p>NOTE(review): this per-square gradient is only exact if {@code evaluate(player)} is the
	 * weighted sum of {@code player.getValue(r, c) * getValueAt(r, c)}. TODO confirm.
	 *
	 * @param previousBoard copy of the board taken before the move was made
	 * @param game          game whose current board is the position after the move
	 */
	private void updateEvaluationFunction(Board previousBoard, BoardGame game) {
		double evalBefore = Math.tanh(previousBoard.evaluate(player));
		// Derivative of tanh at the pre-move evaluation: 1 - tanh(x)^2.
		double derivative = 1 - evalBefore * evalBefore;

		double target;
		if (game.endOfGame()) {
			// Terminal position: learn toward the sign of the outcome (+1, 0 or -1).
			target = Integer.signum(game.getOutcome());
		} else {
			target = Math.tanh(game.getBoard().evaluate(player));
		}

		double delta = learningRate * (target - evalBefore) * derivative;

		// NOTE(review): weights are indexed 1..OthelloBoard.size() even though this scenario is
		// typed against the generic Board/BoardGame, which ties it to Othello board dimensions.
		int size = OthelloBoard.size();
		for (int row = 1; row <= size; row++) {
			for (int col = 1; col <= size; col++) {
				double w = player.getValue(row, col);
				player.setValue(row, col, w + delta * previousBoard.getValueAt(row, col));
			}
		}
	}
}
