
Example 6 with CellWorldEnvironment

Use of aima.core.learning.reinforcement.example.CellWorldEnvironment in project aima-java by aimacode.

From the class QLearningAgentTest, the method setUp:

@Before
public void setUp() {
    cw = CellWorldFactory.createCellWorldForFig17_1();
    cwe = new CellWorldEnvironment(cw.getCellAt(1, 1), cw.getCells(),
            MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
            new JavaRandomizer());
    // QLearningAgent(actionsFunction, noneAction, alpha, gamma, Ne, R+):
    // learning rate alpha = 0.2, discount gamma = 1.0, try each state-action
    // pair at least Ne = 5 times, optimistic reward estimate R+ = 2.0.
    qla = new QLearningAgent<Cell<Double>, CellWorldAction>(
            MDPFactory.createActionsFunctionForFigure17_1(cw),
            CellWorldAction.None, 0.2, 1.0, 5, 2.0);
    cwe.addAgent(qla);
}
Also used: CellWorldAction (aima.core.environment.cellworld.CellWorldAction), JavaRandomizer (aima.core.util.JavaRandomizer), CellWorldEnvironment (aima.core.learning.reinforcement.example.CellWorldEnvironment), Cell (aima.core.environment.cellworld.Cell), Before (org.junit.Before)
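
A minimal usage sketch (not part of the original test class) of how this fixture might be exercised; it assumes only calls already shown in these examples, namely executeTrials(...) and getUtility(), and the test name is hypothetical:

@Test
public void testUtilityIsEstimatedFor1_1() {
    // Let the Q-learning agent interact with the environment for a batch of trials.
    cwe.executeTrials(1000);
    // After learning, the agent should hold a utility estimate for cell (1,1).
    Map<Cell<Double>, Double> u = qla.getUtility();
    Assert.assertNotNull(u.get(cw.getCellAt(1, 1)));
}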

Example 7 with CellWorldEnvironment

Use of aima.core.learning.reinforcement.example.CellWorldEnvironment in project aima-java by aimacode.

From the class ReinforcementLearningAgentTest, the method test_RMSeiu_for_1_1:

public static void test_RMSeiu_for_1_1(ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent, int numRuns, int numTrialsPerRun, double expectedErrorLessThan) {
    CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
    CellWorldEnvironment cwe = new CellWorldEnvironment(cw.getCellAt(1, 1), cw.getCells(),
            MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
            new JavaRandomizer());
    cwe.addAgent(reinforcementAgent);
    Map<Integer, Map<Cell<Double>, Double>> runs = new HashMap<Integer, Map<Cell<Double>, Double>>();
    for (int r = 0; r < numRuns; r++) {
        reinforcementAgent.reset();
        cwe.executeTrials(numTrialsPerRun);
        runs.put(r, reinforcementAgent.getUtility());
    }
    // Calculate the root mean square error (RMSE) of the learned utility of
    // cell (1,1) across all runs, against the reference value 0.705 (the
    // exact utility of (1,1) in Figure 17.3 of AIMA).
    double xSsquared = 0;
    for (int r = 0; r < numRuns; r++) {
        Map<Cell<Double>, Double> u = runs.get(r);
        Double val1_1 = u.get(cw.getCellAt(1, 1));
        if (null == val1_1) {
            throw new IllegalStateException("U(1,1) is not present: r=" + r + ", u=" + u);
        }
        xSsquared += Math.pow(0.705 - val1_1, 2);
    }
    double rmse = Math.sqrt(xSsquared / runs.size());
    Assert.assertTrue("" + rmse + " is not < " + expectedErrorLessThan, rmse < expectedErrorLessThan);
}
Also used: HashMap (java.util.HashMap), JavaRandomizer (aima.core.util.JavaRandomizer), CellWorldEnvironment (aima.core.learning.reinforcement.example.CellWorldEnvironment), Cell (aima.core.environment.cellworld.Cell), Map (java.util.Map)
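
A hypothetical call site for this helper (the run counts and error bound below are illustrative, not taken from the original tests; any ReinforcementAgent works, e.g. the qla fixture from Example 6):

// Over 20 independent runs of 100 trials each, require the RMS error of the
// learned U(1,1) against the reference value 0.705 to stay below 0.1.
test_RMSeiu_for_1_1(qla, 20, 100, 0.1);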

Example 8 with CellWorldEnvironment

Use of aima.core.learning.reinforcement.example.CellWorldEnvironment in project aima-java by aimacode.

From the class ReinforcementLearningAgentTest, the method test_utility_learning_rates:

public static void test_utility_learning_rates(ReinforcementAgent<Cell<Double>, CellWorldAction> reinforcementAgent, int numRuns, int numTrialsPerRun, int rmseTrialsToReport, int reportEveryN) {
    if (rmseTrialsToReport > (numTrialsPerRun / reportEveryN)) {
        throw new IllegalArgumentException("Requested to report too many RMSE trials; the maximum allowed for these arguments is " + (numTrialsPerRun / reportEveryN));
    }
    CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
    CellWorldEnvironment cwe = new CellWorldEnvironment(cw.getCellAt(1, 1), cw.getCells(),
            MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
            new JavaRandomizer());
    cwe.addAgent(reinforcementAgent);
    Map<Integer, List<Map<Cell<Double>, Double>>> runs = new HashMap<Integer, List<Map<Cell<Double>, Double>>>();
    for (int r = 0; r < numRuns; r++) {
        reinforcementAgent.reset();
        List<Map<Cell<Double>, Double>> trials = new ArrayList<Map<Cell<Double>, Double>>();
        for (int t = 0; t < numTrialsPerRun; t++) {
            cwe.executeTrial();
            if (0 == t % reportEveryN) {
                Map<Cell<Double>, Double> u = reinforcementAgent.getUtility();
                if (null == u.get(cw.getCellAt(1, 1))) {
                    throw new IllegalStateException("Bad Utility State Encountered: r=" + r + ", t=" + t + ", u=" + u);
                }
                trials.add(u);
            }
        }
        runs.put(r, trials);
    }
    StringBuilder v4_3 = new StringBuilder();
    StringBuilder v3_3 = new StringBuilder();
    StringBuilder v1_3 = new StringBuilder();
    StringBuilder v1_1 = new StringBuilder();
    StringBuilder v3_2 = new StringBuilder();
    StringBuilder v2_1 = new StringBuilder();
    for (int t = 0; t < (numTrialsPerRun / reportEveryN); t++) {
        // Use the last run
        Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);
        v4_3.append((u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0) + "\t");
        v3_3.append((u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0) + "\t");
        v1_3.append((u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0) + "\t");
        v1_1.append((u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0) + "\t");
        v3_2.append((u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0) + "\t");
        v2_1.append((u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0) + "\t");
    }
    System.out.println("(4,3)" + "\t" + v4_3);
    System.out.println("(3,3)" + "\t" + v3_3);
    System.out.println("(1,3)" + "\t" + v1_3);
    System.out.println("(1,1)" + "\t" + v1_1);
    System.out.println("(3,2)" + "\t" + v3_2);
    System.out.println("(2,1)" + "\t" + v2_1);
    StringBuilder rmseValues = new StringBuilder();
    for (int t = 0; t < rmseTrialsToReport; t++) {
        // Calculate the root mean square error (RMSE) of the learned utility
        // of cell (1,1) at this trial index across all runs.
        double xSsquared = 0;
        for (int r = 0; r < numRuns; r++) {
            Map<Cell<Double>, Double> u = runs.get(r).get(t);
            Double val1_1 = u.get(cw.getCellAt(1, 1));
            if (null == val1_1) {
                throw new IllegalStateException("U(1,1) is not present: r=" + r + ", t=" + t + ", runs.size=" + runs.size() + ", runs(r).size()=" + runs.get(r).size() + ", u=" + u);
            }
            xSsquared += Math.pow(0.705 - val1_1, 2);
        }
        double rmse = Math.sqrt(xSsquared / runs.size());
        rmseValues.append(rmse);
        rmseValues.append("\t");
    }
    System.out.println("RMSeiu" + "\t" + rmseValues);
}
Also used: HashMap (java.util.HashMap), JavaRandomizer (aima.core.util.JavaRandomizer), ArrayList (java.util.ArrayList), List (java.util.List), CellWorldEnvironment (aima.core.learning.reinforcement.example.CellWorldEnvironment), Map (java.util.Map), Cell (aima.core.environment.cellworld.Cell)
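
The six per-cell StringBuilder blocks above are repetitive; a behavior-preserving sketch (assuming the surrounding variables of the method) collapses them into a loop over the tracked cells:

// Print one tab-separated utility trace per tracked cell, using the last run
// as in the original method.
int[][] tracked = { { 4, 3 }, { 3, 3 }, { 1, 3 }, { 1, 1 }, { 3, 2 }, { 2, 1 } };
for (int[] c : tracked) {
    StringBuilder line = new StringBuilder("(" + c[0] + "," + c[1] + ")");
    for (int t = 0; t < (numTrialsPerRun / reportEveryN); t++) {
        Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);
        Double v = u.get(cw.getCellAt(c[0], c[1]));
        // Fall back to 0.0 when no utility has been recorded yet for this cell.
        line.append("\t").append(v != null ? v : 0.0);
    }
    System.out.println(line);
}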

Example 9 with CellWorldEnvironment

Use of aima.core.learning.reinforcement.example.CellWorldEnvironment in project aima-java by aimacode.

From the class LearningDemo, the method passiveADPAgentDemo:

public static void passiveADPAgentDemo() {
    System.out.println("=======================");
    System.out.println("DEMO: Passive-ADP-Agent");
    System.out.println("=======================");
    System.out.println("Figure 21.3");
    System.out.println("-----------");
    CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();
    CellWorldEnvironment cwe = new CellWorldEnvironment(cw.getCellAt(1, 1), cw.getCells(),
            MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),
            new JavaRandomizer());
    // Fixed policy to be evaluated: the optimal policy for the 4x3 world
    // (AIMA Figure 17.2(a)); the terminal cells (4,2) and (4,3) need no action.
    Map<Cell<Double>, CellWorldAction> fixedPolicy = new HashMap<Cell<Double>, CellWorldAction>();
    fixedPolicy.put(cw.getCellAt(1, 1), CellWorldAction.Up);
    fixedPolicy.put(cw.getCellAt(1, 2), CellWorldAction.Up);
    fixedPolicy.put(cw.getCellAt(1, 3), CellWorldAction.Right);
    fixedPolicy.put(cw.getCellAt(2, 1), CellWorldAction.Left);
    fixedPolicy.put(cw.getCellAt(2, 3), CellWorldAction.Right);
    fixedPolicy.put(cw.getCellAt(3, 1), CellWorldAction.Left);
    fixedPolicy.put(cw.getCellAt(3, 2), CellWorldAction.Up);
    fixedPolicy.put(cw.getCellAt(3, 3), CellWorldAction.Right);
    fixedPolicy.put(cw.getCellAt(4, 1), CellWorldAction.Left);
    PassiveADPAgent<Cell<Double>, CellWorldAction> padpa = new PassiveADPAgent<Cell<Double>, CellWorldAction>(
            fixedPolicy, cw.getCells(), cw.getCellAt(1, 1),
            MDPFactory.createActionsFunctionForFigure17_1(cw),
            // Modified policy evaluation with k = 10 iterations and discount gamma = 1.0.
            new ModifiedPolicyEvaluation<Cell<Double>, CellWorldAction>(10, 1.0));
    cwe.addAgent(padpa);
    output_utility_learning_rates(padpa, 20, 100, 100, 1);
    System.out.println("=========================");
}
Also used: CellWorldAction (aima.core.environment.cellworld.CellWorldAction), HashMap (java.util.HashMap), JavaRandomizer (aima.core.util.JavaRandomizer), PassiveADPAgent (aima.core.learning.reinforcement.agent.PassiveADPAgent), CellWorldEnvironment (aima.core.learning.reinforcement.example.CellWorldEnvironment), Cell (aima.core.environment.cellworld.Cell)
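
Because the demo methods in LearningDemo are static, a minimal entry point is enough to run this example (a sketch; the actual LearningDemo class defines its own main):

public static void main(String[] args) {
    // Runs 20 runs of 100 trials each for the passive ADP agent on the
    // 4x3 cell world of Figure 17.1, printing utility traces and RMSeiu values.
    passiveADPAgentDemo();
}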

Aggregations

Cell (aima.core.environment.cellworld.Cell): 9 usages
CellWorldEnvironment (aima.core.learning.reinforcement.example.CellWorldEnvironment): 9 usages
JavaRandomizer (aima.core.util.JavaRandomizer): 9 usages
HashMap (java.util.HashMap): 7 usages
CellWorldAction (aima.core.environment.cellworld.CellWorldAction): 6 usages
Map (java.util.Map): 3 usages
Before (org.junit.Before): 3 usages
ArrayList (java.util.ArrayList): 2 usages
List (java.util.List): 2 usages
PassiveADPAgent (aima.core.learning.reinforcement.agent.PassiveADPAgent): 1 usage
PassiveTDAgent (aima.core.learning.reinforcement.agent.PassiveTDAgent): 1 usage
QLearningAgent (aima.core.learning.reinforcement.agent.QLearningAgent): 1 usage
ModifiedPolicyEvaluation (aima.core.probability.mdp.impl.ModifiedPolicyEvaluation): 1 usage