There is an old truism: “Use the right tool for the job.” However, in building software, we are often forced to nail in screws, just because the rest of the application was built with the figurative hammer Java. Of course, one of the preferred solutions to this problem is microservices, which each handle one task and can be written in the most suitable language.
But what to do if the monolith already exists, or the project is not large enough to justify the increased complexity of microservices? Well, in this case, if tight coupling is unavoidable or even preferred, we can use the approach I am going to show in this blog post. We will learn how to use the machine learning ecosystem of Python to apply reinforcement learning to a system implemented in Java. After training, we can load the model into Java and use it. Therefore, we only use Python during the training and not in production. What’s best about this approach is that it ensures a happy data scientist who can use the right tools for the job.
And since this is about Python: What would be a better toy example than the classic game Snake? (The answer to this rhetorical question is, of course: “Some reference to Monty Python.” But I really could not think of a simple problem about a flying circus.)
The complete source code of our example is available on GitHub.
Snake in Java
We are starting with a Java program implementing the game logic of Snake: there is always a piece of food on the field. Whenever the snake reaches the food, it grows and new food appears. If the snake bites itself or a wall, the game ends.
Our objective is to train a neural net to steer the snake so that it eats as much food as possible before it makes a mistake and the game ends. First, we need a tensor which represents the current state of the game. It acts as the input of our neural net, which uses it to predict the best next step to take. To keep this example simple, our tensor is just a vector of seven elements, each of which is either 1 or 0: the first four indicate whether the food is in front of, left of, right of, or behind the snake, and the next three signal whether the fields to the left of, in front of, and to the right of the snake’s head are blocked by a wall or the tail of the snake.
```java
public class SnakeLogic {
    Coordinate head; // position of the snake's head
    Coordinate food; // position of the food
    Move headDirection; // direction in which the head points

    public boolean[] trainingState() {
        boolean[] state = new boolean[7];

        // get the angle from the head to the food,
        // depending on the direction of movement `headDirection`
        double alpha = angle(head, headDirection, food);

        state[0] = isFoodFront(alpha);
        state[1] = isFoodLeft(alpha);
        state[2] = isFoodRight(alpha);
        state[3] = isFoodBack(alpha);

        // check if there is danger on the neighboring fields
        state[4] = danger(head.left(headDirection));
        state[5] = danger(head.straight(headDirection));
        state[6] = danger(head.right(headDirection));

        return state;
    }

    // omitted other fields and methods for clarity
    // find them at https://github.com/surt91/autosnake
}
```
We will need this method on two occasions: first during the training, where we will call it directly from Python, and later in production, where we will call it from our Java program to give the trained net a basis for making a decision.
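To make the encoding tangible without a JVM, here is a minimal Python sketch of the same seven-element vector. It is not a port of the Java implementation: the helper names (`turn_left`, `turn_right`, `training_state`), the screen-style coordinates with y growing downwards, and the simple sign tests that stand in for the angle-based `isFoodFront` and friends are all assumptions made for illustration. In particular, this sketch may set two food flags at once for diagonal positions, which the Java version may or may not do.

```python
# Hypothetical re-implementation for illustration only; the real logic
# lives in the Java class SnakeLogic.

def turn_left(direction):
    """Rotate a unit step (dx, dy) to the snake's left (y grows downwards)."""
    dx, dy = direction
    return (dy, -dx)


def turn_right(direction):
    """Rotate a unit step (dx, dy) to the snake's right (y grows downwards)."""
    dx, dy = direction
    return (-dy, dx)


def training_state(head, direction, food, blocked):
    """Return [food front, left, right, back, danger left, straight, right]."""
    hx, hy = head
    dx, dy = direction
    lx, ly = turn_left(direction)
    rx, ry = turn_right(direction)

    # offset of the food relative to the head, projected onto the
    # snake's forward and leftward axes
    fx, fy = food[0] - hx, food[1] - hy
    forward = fx * dx + fy * dy
    leftward = fx * lx + fy * ly

    return [
        forward > 0,
        leftward > 0,
        leftward < 0,
        forward < 0,
        (hx + lx, hy + ly) in blocked,
        (hx + dx, hy + dy) in blocked,
        (hx + rx, hy + ry) in blocked,
    ]


# The snake looks "up" on the screen, the food is ahead and to the right,
# and the field to the left of the head is blocked.
state = training_state(head=(5, 5), direction=(0, -1), food=(7, 2), blocked={(4, 5)})
print(state)
```

The `blocked` set stands in for both walls and the snake's tail; how the Java method `danger` decides this is not shown in the post.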
Java classes in Python
Enter JPype! Importing a Java class, without any changes to the Java sources, takes nothing more than the following code:
```python
import jpype
import jpype.imports
from jpype.types import *

# launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])

# import the Java module
from me.schawe.autosnake import SnakeLogic

# construct an object of the `SnakeLogic` class ...
width, height = 10, 10
snake_logic = SnakeLogic(width, height)

# ... and call a method on it
print(snake_logic.trainingState())
```
JPype starts a JVM in the same process as the Python interpreter and lets them communicate using the Java Native Interface (JNI). In a simplified way, one can think of it as calling functions from dynamic libraries (experienced Pythonistas may find a comparison to the module ctypes helpful). But JPype does this in a very comfortable way and automatically maps Java classes to Python classes.
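For readers who have not used ctypes, the comparison can be made concrete in a few lines. This sketch is not JPype code; it only shows the analogous mechanism of loading a native library into the Python process and calling a function in it. It assumes Linux or macOS, where the C standard library can be located this way.

```python
import ctypes
import ctypes.util

# Load the C standard library into the current process. If find_library
# returns None, CDLL(None) falls back to the symbols that are already
# loaded into the process (this works on Linux and macOS, not on Windows).
libc = ctypes.CDLL(ctypes.util.find_library("c"))

# With ctypes, the signature has to be declared by hand ...
libc.strlen.argtypes = [ctypes.c_char_p]
libc.strlen.restype = ctypes.c_size_t

# ... before the native function can be called like a Python function.
length = libc.strlen(b"snake")
print(length)
```

The main difference in comfort is visible here: ctypes needs manual type declarations, whereas JPype derives the Python-side classes and method signatures from the Java classes themselves.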
It should also be noted that there is a surprising number of projects with this objective, each with their own strengths and weaknesses. As representatives, we will quickly look at Jython and Py4J.
Jython runs a Python interpreter directly in the JVM, so that Python and Java can share the same data structures very efficiently. But this comes with a few drawbacks for the usage of native Python libraries: since we will use NumPy and TensorFlow, this is not an option for us.
Py4J is on the opposite side of the spectrum. It starts a socket in the Java code, over which it can communicate with Python programs. The advantage is that an arbitrary number of Python processes can connect to a long-running Java process — or the other way around, one Python process can connect to many JVMs, even over the network. The downside is a larger overhead of the socket communication.
The training
Now that we can access our Java classes, we can use the deep learning framework of our choice — in our case, Keras — to create and train a model. Since we want to train a snake to collect the maximum amount of food, we choose a reinforcement learning approach.
In reinforcement learning, an agent interacts with an environment and is rewarded for good decisions and punished for bad ones. In the past, this discipline has drawn quite some media attention for playing classic Atari games or Go.
For our application, it makes sense to write a class that adheres closely to the interface of the OpenAI Gym environments, since they are a de facto standard for reinforcement learning.
Therefore we need a method `step`, which takes an `action`, simulates a time step, and returns the result of the action. The `action` is the output of the neural net and suggests whether the snake should turn left or right or not at all. The returned result consists of
- `state`, the new state (our vector with seven elements),
- `reward`, our valuation of the action: 1 if the snake could eat food in this step, -1 if the snake bit itself or a wall, and 0 otherwise,
- `done`, an indicator whether the round is finished, i.e. if the snake bit itself or a wall, and
- a dictionary with debugging information, which we just leave empty.
Further, we need a method `reset` to start a new round. It should also return the new state.
Both methods are easy to write thanks to our already existing Java classes:
```python
import jpype
import jpype.imports
from jpype.types import *

# Launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])

# import the Java module
from me.schawe.autosnake import SnakeLogic


class Snake:
    def __init__(self):
        width, height = 10, 10
        # `snakeLogic` is a Java object, such that we can call
        # all its methods. This is also the reason why we
        # name it in camelCase instead of the snake_case
        # convention of Python.
        self.snakeLogic = SnakeLogic(width, height)

    def reset(self):
        self.snakeLogic.reset()

        return self.snakeLogic.trainingState()

    def step(self, action):
        self.snakeLogic.turnRelative(action)
        self.snakeLogic.update()

        state = self.snakeLogic.trainingState()

        done = False
        reward = 0
        if self.snakeLogic.isGameOver():
            reward = -1
            done = True
        elif self.snakeLogic.isEating():
            reward = 1

        return state, reward, done, {}
```
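Any code that drives such an environment only needs `reset` and `step`. The following sketch plays one round with a random policy. So that it runs without a JVM, it uses a hypothetical stand-in class `DummySnake` that merely mimics the interface with random states and rewards (it is not part of the project), and the action values 0, 1 and 2 are assumed placeholders for the three relative turns.

```python
import random


class DummySnake:
    """Hypothetical stand-in with the same reset/step interface as `Snake`."""

    def __init__(self, max_steps=20, seed=0):
        self.max_steps = max_steps
        self.rng = random.Random(seed)
        self.steps = 0

    def _state(self):
        # seven random flags instead of a real game state
        return [self.rng.random() < 0.5 for _ in range(7)]

    def reset(self):
        self.steps = 0
        return self._state()

    def step(self, action):
        # the action is ignored here; the round simply ends with a
        # "crash" after `max_steps` steps
        self.steps += 1
        done = self.steps >= self.max_steps
        reward = -1 if done else self.rng.choice([0, 0, 1])
        return self._state(), reward, done, {}


def run_episode(env, policy):
    """Play one round and return the summed reward and the number of steps."""
    state = env.reset()
    total_reward, steps, done = 0, 0, False
    while not done:
        action = policy(state)
        state, reward, done, _info = env.step(action)
        total_reward += reward
        steps += 1
    return total_reward, steps


policy_rng = random.Random(1)
total_reward, steps = run_episode(DummySnake(), lambda state: policy_rng.choice([0, 1, 2]))
print(steps, total_reward)
```

Because `run_episode` only relies on the two methods, the same loop should work unchanged with the JPype-backed `Snake` class in place of `DummySnake`.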
Now we can easily plug this training environment into the first reinforcement learning example of the Keras documentation and use it directly to start the training.
The snake learns! Within a few minutes, it begins to directly move towards the food and avoids the walls — but it still tends to trap itself quickly. For our purposes this should suffice for now.
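The Keras example itself is too long to reproduce here. The underlying idea of improving decisions from rewards can, however, be illustrated with plain tabular Q-learning, which is a different and much simpler technique than the neural-network training used above. The environment `Corridor` in this sketch is a hypothetical toy with the same reset/step interface, not the snake game: the agent starts in the middle of five cells, receives +1 for reaching the right end and -1 for reaching the left end. All hyperparameters are arbitrary choices.

```python
import random


class Corridor:
    """Hypothetical toy environment: cells 0..4, start in cell 2."""

    def __init__(self):
        self.position = 2

    def reset(self):
        self.position = 2
        return self.position

    def step(self, action):
        # action 1 moves right, action 0 moves left
        self.position += 1 if action == 1 else -1
        if self.position == 4:
            return self.position, 1, True, {}
        if self.position == 0:
            return self.position, -1, True, {}
        return self.position, 0, False, {}


def train(episodes=500, alpha=0.5, gamma=0.9, epsilon=0.2, seed=0):
    """Tabular Q-learning with an epsilon-greedy policy."""
    rng = random.Random(seed)
    env = Corridor()
    q = {(s, a): 0.0 for s in range(5) for a in (0, 1)}
    for _ in range(episodes):
        state, done = env.reset(), False
        while not done:
            if rng.random() < epsilon:
                action = rng.choice((0, 1))  # explore
            else:
                action = max((0, 1), key=lambda a: q[(state, a)])  # exploit
            next_state, reward, done, _info = env.step(action)
            # value of the best follow-up action (0 after the round ended)
            best_next = 0.0 if done else max(q[(next_state, a)] for a in (0, 1))
            q[(state, action)] += alpha * (reward + gamma * best_next - q[(state, action)])
            state = next_state
    return q


q = train()
greedy = {s: max((0, 1), key=lambda a: q[(s, a)]) for s in (1, 2, 3)}
print(greedy)
```

After training, the greedy action in every inner cell is to move right, towards the reward. A table works here because there are only a handful of states; a neural net like the one in the Keras example replaces the table when the state space is too large or the agent should generalize between similar states.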
Load the model in Java
Now we come full circle and load the trained model into Java using deeplearning4j …
```java
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import
public class Autopilot {
    ComputationGraph model;

    public Autopilot(String pathToModel) {
        try {
            model = KerasModelImport.importKerasModelAndWeights(pathToModel, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // infer the next move from the given state
    public int nextMove(boolean[] state) {
        // shape the state into a batch that contains a single sample
        INDArray input = Nd4j.create(state).reshape(1, state.length);
        INDArray output = model.output(input)[0];

        // the index of the largest output is the chosen action
        int action = output.ravel().argMax().getInt(0);

        return action;
    }
}
```
… where we call the same methods used during training to steer the snake.
```java
public class SnakeLogic {
    Autopilot autopilot = new Autopilot("path/to/model.h5");

    public void update() {
        int action = autopilot.nextMove(trainingState());
        turnRelative(action);

        // rest of the update omitted
    }

    // further methods omitted
}
```
Conclusion
It is surprisingly easy to make Java and Python work hand in hand, which can be especially valuable when developing prototypes.
What’s more, it does not have to be deep learning. Since the connection between Java and Python is so easy to use, there certainly is potential to apply this approach to facilitate exploratory data analysis on a database using the full business logic in an IPython notebook.
Regarding our toy example: Given that we did not spend a single thought on the model, the result is surprisingly good. For better results, one probably would have to use the full field as input and think a bit more about the model. Quick googling shows that apparently there are models which can play a perfect game of Snake, such that the snake occupies every single site. For Snake, it might be more sensible to use the neural net between one’s ears to think of a perfect strategy. For example, we can ensure a perfect game if the snake always moves on a Hamiltonian path between its head and the tip of its tail (i.e. a path which visits all sites except those occupied by the snake). How to find Hamiltonian paths efficiently will be left to the reader as an exercise.
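One special case needs no search at all: on a field with an even number of rows, a fixed Hamiltonian cycle can be written down directly, and a snake that follows it forever cannot run into itself before it fills the whole field. The sketch below constructs such a cycle. The function name `hamiltonian_cycle` and the zigzag layout are illustrative choices that are not taken from the project, and the strategy is slow, since the snake never takes a shortcut towards the food.

```python
def hamiltonian_cycle(width, height):
    """Return all cells of a width x height grid as one closed tour.

    Layout: walk along the top row, zigzag through the remaining rows
    while leaving out column 0, then return upwards through column 0.
    """
    if width < 2 or height < 2 or height % 2 != 0:
        raise ValueError("this construction needs width >= 2 and an even height >= 2")

    # top row, left to right
    cycle = [(x, 0) for x in range(width)]

    # remaining rows: odd rows run right to left, even rows left to right,
    # always staying in the columns 1 .. width - 1
    for y in range(1, height):
        xs = range(width - 1, 0, -1) if y % 2 == 1 else range(1, width)
        cycle.extend((x, y) for x in xs)

    # back to the start through column 0, bottom to top
    cycle.extend((0, y) for y in range(height - 1, 0, -1))
    return cycle


cycle = hamiltonian_cycle(10, 10)

# lookup table: for every cell, the cell the snake should move to next
successor = {cell: cycle[(i + 1) % len(cycle)] for i, cell in enumerate(cycle)}
print(len(cycle), successor[(0, 0)])
```

Finding a Hamiltonian path between head and tail that also reaches the food quickly is the harder problem, and it remains the exercise.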
Blog author
Hendrik Schawe
IT-Consultant