From cebf7bd20784ebbb38b23586636a585d61cc0b1c Mon Sep 17 00:00:00 2001 From: Zeyu Yao Date: Tue, 2 May 2023 21:48:25 +0800 Subject: [PATCH 1/4] Reinitiate repository --- .gitignore | 4 - Main.java | 104 ++++++---------------- MoveSorter.java | 30 ------- Position.java | 189 +++++++++++----------------------------- README.md | 28 ------ Solver.java | 148 +++++++------------------------ TranspositionTable.java | 59 ------------- instructions.txt | 20 ----- util/ClearConsole.java | 8 -- util/ConsoleColors.java | 76 ---------------- 10 files changed, 113 insertions(+), 553 deletions(-) delete mode 100644 MoveSorter.java delete mode 100644 README.md delete mode 100644 TranspositionTable.java delete mode 100644 instructions.txt delete mode 100644 util/ClearConsole.java delete mode 100644 util/ConsoleColors.java diff --git a/.gitignore b/.gitignore index 84560c3..59c7d91 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,2 @@ # Test dataset from Pascal Pons' tutorial test - -# Test files -SolverTest.java -PlayTest.java \ No newline at end of file diff --git a/Main.java b/Main.java index d95cce3..0f51874 100644 --- a/Main.java +++ b/Main.java @@ -1,94 +1,44 @@ +import java.io.BufferedReader; +import java.io.FileReader; import java.util.Scanner; -import java.io.*; - -// custom import -import util.*; public class Main { - public static void main(String[] args) throws Exception { - // create objects - Position pos = new Position(); + public static void main(String[] args) { Solver solver = new Solver(); - Scanner in = new Scanner(System.in); - - // clear console - ClearConsole.clear(); + Scanner scanner = new Scanner(System.in); + System.out.print("Enter test batch to run: "); + String testBatch = scanner.nextLine(); + scanner.close(); + + try { + BufferedReader br = new BufferedReader(new FileReader("test/" + testBatch)); + String line; - // read in the text from instructions.txt and print the instructions - File file = new File("instructions.txt"); - BufferedReader br = new BufferedReader(new FileReader(file)); + System.out.println("Running test batch " + testBatch + "..."); - String line; - try { while ((line = br.readLine()) != null) { - System.out.println(line); - } - } catch (IOException e) { - e.printStackTrace(); - } + String[] split = line.split(" "); + String seq = split[0]; + int expectedScore = Integer.parseInt(split[1]); - br.close(); - - // newline - System.out.println(); + Position P = new Position(); + P.play(seq); - // generate a random move string of length between 10 and 20 - // keep generating a random move string until the string doesn't have any winning moves - int length = (int) (Math.random() * 10 + 10); - String moveString = ""; + int score = solver.solve(P); - while (true) { - Position sudoPosition = new Position(); - - for (int i = 0; i < length; i++) { - moveString += (int) (Math.random() * 7 + 1); - } - - int moveCount = sudoPosition.play(moveString); - if (moveCount == moveString.length()) { - break; - } else { - moveString = ""; + if (score != expectedScore) { + System.out.println("Error: " + seq + " " + score + " " + expectedScore); + br.close(); + return; + } } - } - // get user input for the move string and the heuristic setting - System.out.print("Do you want to use a random move string? (y/n): "); - String input = in.nextLine(); + br.close(); - if (input.equals("y")) { - System.out.println("Using random move string: " + moveString); - } else { - System.out.print("Enter the move string: "); - moveString = in.nextLine(); + System.out.println("Test batch " + testBatch + " passed!"); + } catch (Exception e) { + System.out.println(e); } - - System.out.print("Enter a heuristic setting (0 for weak, 1 for strong): "); - boolean weak = (in.nextInt() == 0); - - // newline - System.out.println(); - - System.out.println("Solving..."); - - // newline - System.out.println(); - - // play the move string - int moveCount = pos.play(moveString); - if (moveCount != moveString.length()) { - System.out.println("Invalid move: shutting down"); - } else { - // solve the position - int startTime = (int) System.currentTimeMillis(); - int score = solver.solve(pos, weak); - int endTime = (int) System.currentTimeMillis(); - - System.out.println("Position solved in " + pos.getMoves() + " moves in " + (endTime - startTime) + "ms: Score = " + score); - System.out.println(ConsoleColors.GREEN + "Optimal next move: column " + solver.chooseMove(pos, weak) + ConsoleColors.RESET); - } - - in.close(); } } diff --git a/MoveSorter.java b/MoveSorter.java deleted file mode 100644 index 1a92d61..0000000 --- a/MoveSorter.java +++ /dev/null @@ -1,30 +0,0 @@ -public class MoveSorter { - private int size; - - // Contains size moves with their score ordered by score - private Move[] entries = new Move[Position.WIDTH]; - - private class Move { - private long move; - public Move(long move, int score) { - this.move = move; - } - } - - public void add(long move, int score) { - int pos = size++; - // shift all entries to the right to make room for the new entry - for (int i = size - 1; i > pos; i--) { - entries[i] = entries[i - 1]; - } - entries[pos] = new Move(move, score); - } - - public long getNext() { - if (size != 0) { - return entries[--size].move; - } else { - return 0; - } - } -} diff --git a/Position.java b/Position.java index 5603c82..ce562ad 100644 --- a/Position.java +++ b/Position.java @@ -1,166 +1,83 @@ public class Position { - public static final int WIDTH = 7; // width of the board - public static final int HEIGHT = 6; // height of the board - public static final int MIN_SCORE = -(WIDTH * HEIGHT) / 2 + 3; - public static final int MAX_SCORE = (WIDTH * HEIGHT + 1) / 2 - 3; + public static final int WIDTH = 7; + public static final int HEIGHT = 6; - private long current_position; - private long mask; - private int moves; // number of moves played since the beinning of the game - - public static final long bottom(int width, int height) { - return width == 0 ? 0 : bottom(width - 1, height) | 1L << (width - 1) * (height + 1); - } - - /** Masks */ - private static final long bottom_mask = bottom(WIDTH, HEIGHT); - private static final long board_mask = bottom_mask * ((1L << HEIGHT) - 1); - - public static final long column_mask(int col) { - return ((1L << HEIGHT) - 1) << col * (HEIGHT + 1); - } - - private static final long top_mask_col(int col) { - return 1L << ((HEIGHT - 1) + col * (HEIGHT + 1)); - } - - private static final long bottom_mask_col(int col) { - return 1L << (col * (HEIGHT + 1)); - } + private int board[][] = new int[WIDTH][HEIGHT]; + private int tokens[] = new int[WIDTH]; + private int moves = 0; - /** Constructor */ public Position() { - current_position = 0; - mask = 0; + for (int i = 0; i < WIDTH; i++) { + tokens[i] = 0; + for (int j = 0; j < HEIGHT; j++) { + board[i][j] = 0; + } + } + moves = 0; } - public Position(Position pos) { - current_position = pos.current_position; - mask = pos.mask; - moves = pos.moves; + public Position(Position P) { + for (int i = 0; i < WIDTH; i++) { + tokens[i] = P.tokens[i]; + for (int j = 0; j < HEIGHT; j++) { + board[i][j] = P.board[i][j]; + } + } + + moves = P.moves; } - /** Play a move - * @param move the move to play - */ - public void play(long move) { - current_position ^= mask; - mask |= move; - moves++; + public boolean playable(int col) { + return (tokens[col] < HEIGHT); } public void play(int col) { - play(column_mask(col)); + board[col][tokens[col]] = moves % 2 + 1; + tokens[col]++; + moves++; } - /** Play a sequence of moves - * @param seq the sequence of moves to play - */ public int play(String seq) { for (int i = 0; i < seq.length(); i++) { - int column = seq.charAt(i) - '1'; - if (column < 0 || column >= Position.WIDTH || !canPlay(column) || isWinningMove(column)) return i; // invalid move - playColumn(column); + int col = Character.getNumericValue(seq.charAt(i)) - 1; + if (col < 0 || col >= WIDTH || !playable(col) || winsByPlaying(col)) { + return i; + } + play(col); } return seq.length(); } - /** Check if the current player can win in the next move */ - public boolean canWinNext() { - return (winning_position() & possible()) != 0; - } + public boolean winsByPlaying(int col) { + int player = (moves % 2) + 1; - public long possibleNonLosingMoves() { - long possible_mask = possible(); - long opponent_win = opponent_winning_position(); - long forced_moves = possible_mask & opponent_win; - - if (forced_moves != 0) { - if ((forced_moves & (forced_moves - 1)) != 0) { - return 0; - } else { - possible_mask = forced_moves; + if (tokens[col] >= 3) { + if (board[col][tokens[col] - 1] == player && board[col][tokens[col] - 2] == player + && board[col][tokens[col] - 3] == player) { + return true; } } - - return possible_mask & ~(opponent_win >> 1); - } - - public int moveScore(long move) { - return popcount(compute_winning_position(current_position | move, mask)); - } - - public boolean canPlay(int column) { - return (mask & top_mask_col(column)) == 0; - } - - /** Getters */ - public int getMoves() { - return moves; - } - - public long getKey() { - return current_position + mask; - } - - private void playColumn(int col) { - play((mask + bottom_mask_col(col)) & column_mask(col)); - } - - private boolean isWinningMove(int column) { - return (winning_position() & possible() & column_mask(column)) != 0; - } - private long winning_position() { - return compute_winning_position(current_position, mask); - } - - private long opponent_winning_position() { - return compute_winning_position(current_position ^ mask, mask); - } - - private long possible() { - return (mask + bottom_mask) & board_mask; - } - - private static int popcount(long m) { - int c = 0; + for (int dy = -1; dy <= 1; dy++) { + int aligned = 0; + for (int dx = -1; dx <= 1; dx += 2) { + for (int x = col + dx, y = tokens[col] + dx * dy; x >= 0 && x < WIDTH && y >= 0 && y < HEIGHT + && board[x][y] == player; aligned++) { + x += dx; + y += dx * dy; + } + } - for (c = 0; m != 0; c++) { - m &= m - 1; + if (aligned >= 3) { + return true; + } } - return c; + return false; } - private long compute_winning_position(long position, long mask) { - // vertical - long r = (position << 1) & (position << 2) & (position << 3); - - //horizontal - long p = (position << (HEIGHT + 1)) & (position << 2 * (HEIGHT + 1)); - r |= p & (position << 3 * (HEIGHT + 1)); - r |= p & (position >> (HEIGHT + 1)); - p = (position >> (HEIGHT + 1)) & (position >> 2 * (HEIGHT + 1)); - r |= p & (position << (HEIGHT + 1)); - r |= p & (position >> 3 * (HEIGHT + 1)); - - //diagonals - p = (position << HEIGHT) & (position << 2 * HEIGHT); - r |= p & (position << 3 * HEIGHT); - r |= p & (position >> HEIGHT); - p = (position >> HEIGHT) & (position >> 2 * HEIGHT); - r |= p & (position << HEIGHT); - r |= p & (position >> 3 * HEIGHT); - - p = (position << (HEIGHT + 2)) & (position << 2 * (HEIGHT + 2)); - r |= p & (position << 3 * (HEIGHT + 2)); - r |= p & (position >> (HEIGHT + 2)); - p = (position >> (HEIGHT + 2)) & (position >> 2 * (HEIGHT + 2)); - r |= p & (position << (HEIGHT + 2)); - r |= p & (position >> 3 * (HEIGHT + 2)); - - return r & (board_mask ^ mask); + public int getMoves() { + return moves; } -} \ No newline at end of file +} diff --git a/README.md b/README.md deleted file mode 100644 index 3ea7aee..0000000 --- a/README.md +++ /dev/null @@ -1,28 +0,0 @@ -
- 🚀 I'm rebooting this project! -

I'm building this project (again) from the ground up, this time using a more object-oriented approach. I'm also using this project to learn more about alpha-beta pruning and transposition tables in order to improve the performance of the solver.

-

I will be posting updates on my progress on my Twitter, so follow me there if you want to stay updated!

-
- -# Solving Connect Four - -## Premise - -Connect Four is a two-player game where the players take turns dropping colored discs into a 7x6 grid. The first player to get four of their discs in a row (either vertically, horizontally, or diagonally) is the winner. The game ends when the board is full or the players run out of discs, at which point the game is a draw [(1)](https://en.wikipedia.org/wiki/Connect_Four#Gameplay). It is also a strongly solved perfect information strategy game: first player has a winning strategy whatever his opponent plays [(2)](https://en.wikipedia.org/wiki/Connect_Four#Mathematical_solution). - -## Project details - -This projects aims to implement a Connect Four solver in Java. Then, it will be tested against some common test cases and eventually be used to play against a real opponent. - -I have previously done this project using a simple Minimax algorithm, but I have decided to explore a more complex algorithm in order to create a more robust solver. - -## Success criteria - -The solver should be able to solve a given board in a reasonable amount of time: with every move, the solver should be able to determine whether it is winning or losing. - -- If it is winning, it should return a winning move for itself. -- If it is losing, it should block the opponent from winning if there is a winning move available for the opponent. - -## Acknowledgements - -I would like to thank [@PascalPons](https://github.com/PascalPons) for his tutorial on [the perfect Connect Four Solver](http://blog.gamesolver.org/): it was a great help to understand how the algorithm should be implemented. Additionally, many thanks to [Nick Drohan](https://www.linkedin.com/in/nick-drohan-b8a75014/) and [Rishab Nayak](https://github.com/rishabnayak) for guiding me through this and providing assistance - you guys are amazing. diff --git a/Solver.java b/Solver.java index 13e0af1..9544353 100644 --- a/Solver.java +++ b/Solver.java @@ -1,134 +1,52 @@ public class Solver { - private static long nodeCount; - private static int columnOrder[] = new int[Position.WIDTH]; + private long nodes = 0; - private static TranspositionTable transpositionTable; + private int negamax(Position pos) { + nodes++; - /** Constructor */ - public Solver() { - nodeCount = 0; - transpositionTable = new TranspositionTable(8388593); - - reset(); - for (int i = 0; i < Position.WIDTH; i++) { - columnOrder[i] = Position.WIDTH / 2 + (1 - 2 * (i % 2)) * (i + 1) / 2; + if (pos.getMoves() == Position.WIDTH * Position.HEIGHT) { + return 0; } - } - - /** A position has - * a positive score if the current player has a winning move, - * a negative score if the opponent has a winning move, - * and a score of 0 if neither player has a winning move. - */ - private static int negamax(Position pos, int alpha, int beta) { - nodeCount++; - - long next = pos.possibleNonLosingMoves(); - if (next == 0) return -(Position.WIDTH * Position.HEIGHT - pos.getMoves()) / 2; - if (pos.getMoves() >= Position.WIDTH * Position.HEIGHT - 2) return 0; - int min = -(Position.WIDTH * Position.HEIGHT - 2 - pos.getMoves()) / 2; - if (alpha < min) { - alpha = min; - if(alpha >= beta) return alpha; - } - - int max = (Position.WIDTH * Position.HEIGHT - 1 - pos.getMoves()) / 2; - if (beta > max) { - beta = max; - if(alpha >= beta) return beta; - } - - MoveSorter moves = new MoveSorter(); - - for(int i = Position.WIDTH - 1; i >= 0; i--) { - long move; - if((move = next & Position.column_mask(columnOrder[i])) != 0) { - moves.add(move, pos.moveScore(move)); + for (int x = 0; x < Position.WIDTH; x++) { + if (pos.playable(x) && pos.winsByPlaying(x)) { + return (Position.WIDTH * Position.HEIGHT + 1 - pos.getMoves()) / 2; } } - long nextMove; - while ((nextMove = moves.getNext()) != 0) { - Position next_pos = new Position(pos); - next_pos.play(nextMove); - int score = -negamax(next_pos, -beta, -alpha); - if (score >= beta) return score; - if(score > alpha) alpha = score; - } - - transpositionTable.put(pos.getKey(), (byte) (alpha - Position.MIN_SCORE + 1)); - return alpha; - } - - public int solve(Position pos, boolean weak) { - if (pos.canWinNext()) { - return (Position.WIDTH * Position.HEIGHT + 1 - pos.getMoves()) / 2; - } - - int min = -(Position.WIDTH * Position.HEIGHT - pos.getMoves()) / 2; - int max = (Position.WIDTH * Position.HEIGHT + 1 - pos.getMoves()) / 2; - - if(weak) { - min = -1; - max = 1; - } - - // iteratively narrow the min-max exploration window - while(min < max) { - int med = min + (max - min) / 2; - if (med <= 0 && min/2 < med) { - med = min / 2; - } else if (med >= 0 && med/2 > med) { - med = med / 2; - } - - // use a null depth window - int r = negamax(pos, med, med + 1); - if(r <= med) { - max = r; - } else { - min = r; + int bestScore = -Position.WIDTH * Position.HEIGHT; + for (int x = 0; x < Position.WIDTH; x++) { + if (pos.playable(x)) { + Position next = new Position(pos); + next.play(x); + int score = -negamax(next); + if (score > bestScore) { + bestScore = score; + } } } - return min; + return bestScore; } - /** Get the best move for the current player to play. - * @param pos the current position - * @return the column number of the best move - */ - public int chooseMove(Position pos, boolean weak) { - // perform negamax on a single column - int best_move = -1; - int best_score = -(Position.WIDTH * Position.HEIGHT - pos.getMoves()) / 2; + public int solve(Position pos) { + nodes = 0; + return negamax(pos); + } - for (int i = 0; i < Position.WIDTH; i++) { - long move = pos.possibleNonLosingMoves() & Position.column_mask(i); - if (move != 0) { - Position next_pos = new Position(pos); - next_pos.play(move); - int score = -solve(next_pos, weak); + public long getNodes() { + return nodes; + } +} - if (score > best_score) { - best_score = score; - best_move = i; - } - } - } +class Timer { + private long start; - return best_move; - } - - /** Reset the solver */ - public void reset() { - nodeCount = 0; - transpositionTable.reset(); + public Timer() { + start = System.nanoTime(); } - - /** Get the number of nodes visited */ - public long getNodeCount() { - return nodeCount; + + public long elapsed() { + return System.nanoTime() - start; } } diff --git a/TranspositionTable.java b/TranspositionTable.java deleted file mode 100644 index 6abcc7f..0000000 --- a/TranspositionTable.java +++ /dev/null @@ -1,59 +0,0 @@ -public class TranspositionTable { - /** An entry in the transposition table. - * Each entry has a 56 bit key and an 8 bit score. - */ - private class Entry { - private long key; - private byte value; - - public Entry(long key, byte value) { - this.key = key; - this.value = value; - } - } - - private Entry[] table; - - /** Get the index of the entry with the given key. - * @param key The key of the entry. - * @return The index of the entry with the given key. - */ - private long index(long key) { - return key % table.length; - } - - /** Constructor */ - public TranspositionTable(long size) { - table = new Entry[(int) size]; - } - - /** Reset the transposition table. */ - public void reset() { - for (int i = 0; i < table.length; i++) { - table[i] = new Entry((long) 0, (byte) 0); - } - } - - /** Insert an entry into the transposition table. - * @param key The key of the entry. - * @param value The value of the entry. - */ - public void put(long key, byte value) { - int index = (int) index(key); - table[index].key = key; - table[index].value = value; - } - - /** Get the value of the entry with the given key. - * @param key The key of the entry. - * @return The value of the entry with the given key. - */ - public byte get(long key) { - int index = (int) index(key); - if (table[index].key == key) { - return table[index].value; - } else { - return 0; - } - } -} diff --git a/instructions.txt b/instructions.txt deleted file mode 100644 index b7cf6e3..0000000 --- a/instructions.txt +++ /dev/null @@ -1,20 +0,0 @@ -This is a demo of my final project - a Connect Four solver algorithm written in Java. - -It is based on the Minimax algorithm, but with a few modifications to make it more efficient across different board configurations: - -1. It uses alpha-beta pruning to prune the search tree as soon as we know that the score of the position is greater than beta. -2. It uses a bitmap encoding of positions to reduce significantly the computation time. -3. It uses a transposition table by caching the outcome of previous computation to avoid recomputation. - -When the algorithm is finished, it will print out the best score and the best move it found. It will also print out the time it took to compute the best move. - -If you find any issues or have any suggestions, please let me know. - -Thanks for reading! -—————————————————————————————————————————————————————————————————————————————————————————————————————————————————————— - -Inputs: -- a string of the current board configuration (for example, "1234567" to cover the bottom row). Note that as the algorithm has an exponential complexity, it is not suitable for predicting the best move for a position with less than 10 pieces. -- a boolean indicating whether the user wants to use a weak or strong heuristic: the program defaults to the strong heuristic. - -Output: The score of the current board configuration, as well as the best move to play for the current player. diff --git a/util/ClearConsole.java b/util/ClearConsole.java deleted file mode 100644 index 9e5aa6f..0000000 --- a/util/ClearConsole.java +++ /dev/null @@ -1,8 +0,0 @@ -package util; - -public class ClearConsole { - public static void clear() { - System.out.print("\033[H\033[2J"); // Esc + move cursor + clear screen - System.out.flush(); - } -} diff --git a/util/ConsoleColors.java b/util/ConsoleColors.java deleted file mode 100644 index 760e1cc..0000000 --- a/util/ConsoleColors.java +++ /dev/null @@ -1,76 +0,0 @@ -package util; - -public class ConsoleColors { - // Reset - public static final String RESET = "\033[0m"; // Text Reset - - // Regular colors - public static final String BLACK = "\033[0;30m"; // BLACK - public static final String RED = "\033[0;31m"; // RED - public static final String GREEN = "\033[0;32m"; // GREEN - public static final String YELLOW = "\033[0;33m"; // YELLOW - public static final String BLUE = "\033[0;34m"; // BLUE - public static final String PURPLE = "\033[0;35m"; // PURPLE - public static final String CYAN = "\033[0;36m"; // CYAN - public static final String WHITE = "\033[0;37m"; // WHITE - - // Bold - public static final String BLACK_BOLD = "\033[1;30m"; // BLACK - public static final String RED_BOLD = "\033[1;31m"; // RED - public static final String GREEN_BOLD = "\033[1;32m"; // GREEN - public static final String YELLOW_BOLD = "\033[1;33m"; // YELLOW - public static final String BLUE_BOLD = "\033[1;34m"; // BLUE - public static final String PURPLE_BOLD = "\033[1;35m"; // PURPLE - public static final String CYAN_BOLD = "\033[1;36m"; // CYAN - public static final String WHITE_BOLD = "\033[1;37m"; // WHITE - - // Underline - public static final String BLACK_UNDERLINED = "\033[4;30m"; // BLACK - public static final String RED_UNDERLINED = "\033[4;31m"; // RED - public static final String GREEN_UNDERLINED = "\033[4;32m"; // GREEN - public static final String YELLOW_UNDERLINED = "\033[4;33m"; // YELLOW - public static final String BLUE_UNDERLINED = "\033[4;34m"; // BLUE - public static final String PURPLE_UNDERLINED = "\033[4;35m"; // PURPLE - public static final String CYAN_UNDERLINED = "\033[4;36m"; // CYAN - public static final String WHITE_UNDERLINED = "\033[4;37m"; // WHITE - - // Background - public static final String BLACK_BACKGROUND = "\033[40m"; // BLACK - public static final String RED_BACKGROUND = "\033[41m"; // RED - public static final String GREEN_BACKGROUND = "\033[42m"; // GREEN - public static final String YELLOW_BACKGROUND = "\033[43m"; // YELLOW - public static final String BLUE_BACKGROUND = "\033[44m"; // BLUE - public static final String PURPLE_BACKGROUND = "\033[45m"; // PURPLE - public static final String CYAN_BACKGROUND = "\033[46m"; // CYAN - public static final String WHITE_BACKGROUND = "\033[47m"; // WHITE - - // High intensity - public static final String BLACK_BRIGHT = "\033[0;90m"; // BLACK - public static final String RED_BRIGHT = "\033[0;91m"; // RED - public static final String GREEN_BRIGHT = "\033[0;92m"; // GREEN - public static final String YELLOW_BRIGHT = "\033[0;93m"; // YELLOW - public static final String BLUE_BRIGHT = "\033[0;94m"; // BLUE - public static final String PURPLE_BRIGHT = "\033[0;95m"; // PURPLE - public static final String CYAN_BRIGHT = "\033[0;96m"; // CYAN - public static final String WHITE_BRIGHT = "\033[0;97m"; // WHITE - - // Bold and high intensity - public static final String BLACK_BOLD_BRIGHT = "\033[1;90m"; // BLACK - public static final String RED_BOLD_BRIGHT = "\033[1;91m"; // RED - public static final String GREEN_BOLD_BRIGHT = "\033[1;92m"; // GREEN - public static final String YELLOW_BOLD_BRIGHT = "\033[1;93m"; // YELLOW - public static final String BLUE_BOLD_BRIGHT = "\033[1;94m"; // BLUE - public static final String PURPLE_BOLD_BRIGHT = "\033[1;95m"; // PURPLE - public static final String CYAN_BOLD_BRIGHT = "\033[1;96m"; // CYAN - public static final String WHITE_BOLD_BRIGHT = "\033[1;97m"; // WHITE - - // High intensity backgrounds - public static final String BLACK_BACKGROUND_BRIGHT = "\033[0;100m"; // BLACK - public static final String RED_BACKGROUND_BRIGHT = "\033[0;101m"; // RED - public static final String GREEN_BACKGROUND_BRIGHT = "\033[0;102m"; // GREEN - public static final String YELLOW_BACKGROUND_BRIGHT = "\033[0;103m"; // YELLOW - public static final String BLUE_BACKGROUND_BRIGHT = "\033[0;104m"; // BLUE - public static final String PURPLE_BACKGROUND_BRIGHT = "\033[0;105m"; // PURPLE - public static final String CYAN_BACKGROUND_BRIGHT = "\033[0;106m"; // CYAN - public static final String WHITE_BACKGROUND_BRIGHT = "\033[0;107m"; // WHITE -} From e3ea8cc965af897bbc18729d0f0504bf42e3c2f9 Mon Sep 17 00:00:00 2001 From: Zeyu Yao Date: Tue, 2 May 2023 22:15:38 +0800 Subject: [PATCH 2/4] Add alpha-beta pruning and move exploration order --- Main.java | 4 +++- Solver.java | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/Main.java b/Main.java index 0f51874..6ab50a7 100644 --- a/Main.java +++ b/Main.java @@ -24,7 +24,7 @@ public static void main(String[] args) { Position P = new Position(); P.play(seq); - + int score = solver.solve(P); if (score != expectedScore) { @@ -32,6 +32,8 @@ public static void main(String[] args) { br.close(); return; } + + System.out.println(seq + " " + score); } br.close(); diff --git a/Solver.java b/Solver.java index 9544353..d3fc1da 100644 --- a/Solver.java +++ b/Solver.java @@ -1,7 +1,16 @@ public class Solver { private long nodes = 0; + private int[] columnOrder = new int[Position.WIDTH]; - private int negamax(Position pos) { + public Solver() { + nodes = 0; + for (int i = 0; i < Position.WIDTH; i++) { + columnOrder[i] = Position.WIDTH / 2 + (1 - 2 * (i % 2)) * (i + 1) / 2; + } + } + + private int negamax(Position pos, int alpha, int beta) { + assert alpha < beta; nodes++; if (pos.getMoves() == Position.WIDTH * Position.HEIGHT) { @@ -14,24 +23,34 @@ private int negamax(Position pos) { } } - int bestScore = -Position.WIDTH * Position.HEIGHT; + int max = (Position.WIDTH * Position.HEIGHT - 1 - pos.getMoves()) / 2; + if (beta > max) { + beta = max; + if (alpha >= beta) { + return beta; + } + } + for (int x = 0; x < Position.WIDTH; x++) { - if (pos.playable(x)) { + if (pos.playable(columnOrder[x])) { Position next = new Position(pos); - next.play(x); - int score = -negamax(next); - if (score > bestScore) { - bestScore = score; + next.play(columnOrder[x]); + int score = -negamax(next, -beta, -alpha); + if (score >= beta) { + return score; + } + if (score > alpha) { + alpha = score; } } } - return bestScore; + return alpha; } public int solve(Position pos) { nodes = 0; - return negamax(pos); + return negamax(pos, -Position.WIDTH * Position.HEIGHT / 2, Position.WIDTH * Position.HEIGHT / 2); } public long getNodes() { From 2c8bfddc4067b1c241f07f2b97b693af0eb6116b Mon Sep 17 00:00:00 2001 From: Zeyu Yao Date: Tue, 2 May 2023 22:32:51 +0800 Subject: [PATCH 3/4] Bitboard implementation --- Position.java | 88 ++++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/Position.java b/Position.java index ce562ad..5811fb6 100644 --- a/Position.java +++ b/Position.java @@ -2,39 +2,29 @@ public class Position { public static final int WIDTH = 7; public static final int HEIGHT = 6; - private int board[][] = new int[WIDTH][HEIGHT]; - private int tokens[] = new int[WIDTH]; - private int moves = 0; + private long current_position; + private long mask; + private int moves; public Position() { - for (int i = 0; i < WIDTH; i++) { - tokens[i] = 0; - for (int j = 0; j < HEIGHT; j++) { - board[i][j] = 0; - } - } - + current_position = 0; + mask = 0; moves = 0; } public Position(Position P) { - for (int i = 0; i < WIDTH; i++) { - tokens[i] = P.tokens[i]; - for (int j = 0; j < HEIGHT; j++) { - board[i][j] = P.board[i][j]; - } - } - + current_position = P.current_position; + mask = P.mask; moves = P.moves; } public boolean playable(int col) { - return (tokens[col] < HEIGHT); + return (mask & top_mask(col)) == 0; } public void play(int col) { - board[col][tokens[col]] = moves % 2 + 1; - tokens[col]++; + current_position ^= mask; + mask |= (mask + bottom_mask(col)); moves++; } @@ -50,34 +40,52 @@ public int play(String seq) { } public boolean winsByPlaying(int col) { - int player = (moves % 2) + 1; + long pos = current_position; + pos |= (mask + bottom_mask(col)) & column_mask(col); + return alignment(pos); + } - if (tokens[col] >= 3) { - if (board[col][tokens[col] - 1] == player && board[col][tokens[col] - 2] == player - && board[col][tokens[col] - 3] == player) { - return true; - } + public int getMoves() { + return moves; + } + + public long key() { + return current_position + mask; + } + + private boolean alignment(long pos) { + long m = pos & (pos >> (HEIGHT + 1)); + if ((m & (m >> (2 * (HEIGHT + 1)))) != 0) { + return true; } - for (int dy = -1; dy <= 1; dy++) { - int aligned = 0; - for (int dx = -1; dx <= 1; dx += 2) { - for (int x = col + dx, y = tokens[col] + dx * dy; x >= 0 && x < WIDTH && y >= 0 && y < HEIGHT - && board[x][y] == player; aligned++) { - x += dx; - y += dx * dy; - } - } + m = pos & (pos >> HEIGHT); + if ((m & (m >> (2 * HEIGHT))) != 0) { + return true; + } - if (aligned >= 3) { - return true; - } + m = pos & (pos >> (HEIGHT + 2)); + if ((m & (m >> (2 * (HEIGHT + 2)))) != 0) { + return true; + } + + m = pos & (pos >> 1); + if ((m & (m >> 2)) != 0) { + return true; } return false; } - public int getMoves() { - return moves; + private long top_mask(int col) { + return (1L << (HEIGHT - 1)) << col * (HEIGHT + 1); + } + + private long bottom_mask(int col) { + return 1L << col * (HEIGHT + 1); + } + + private long column_mask(int col) { + return ((1L << HEIGHT) - 1) << col * (HEIGHT + 1); } } From a0fe0771010c7d5188c3b5b745372af8ccf6fedc Mon Sep 17 00:00:00 2001 From: Zeyu Yao Date: Tue, 2 May 2023 23:10:55 +0800 Subject: [PATCH 4/4] Transposition table implementation --- Position.java | 2 ++ Solver.java | 16 +++++++++++++++- TranspositionTable.java | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 TranspositionTable.java diff --git a/Position.java b/Position.java index 5811fb6..44256d2 100644 --- a/Position.java +++ b/Position.java @@ -1,6 +1,8 @@ public class Position { public static final int WIDTH = 7; public static final int HEIGHT = 6; + public static final int MIN_SCORE = -(WIDTH * HEIGHT) / 2 + 3; + public static final int MAX_SCORE = (WIDTH * HEIGHT + 1) / 2 - 3; private long current_position; private long mask; diff --git a/Solver.java b/Solver.java index d3fc1da..38ba8dd 100644 --- a/Solver.java +++ b/Solver.java @@ -1,9 +1,11 @@ public class Solver { private long nodes = 0; private int[] columnOrder = new int[Position.WIDTH]; + private TranspositionTable transpositionTable = new TranspositionTable(8388593); public Solver() { - nodes = 0; + reset(); + for (int i = 0; i < Position.WIDTH; i++) { columnOrder[i] = Position.WIDTH / 2 + (1 - 2 * (i % 2)) * (i + 1) / 2; } @@ -24,6 +26,12 @@ private int negamax(Position pos, int alpha, int beta) { } int max = (Position.WIDTH * Position.HEIGHT - 1 - pos.getMoves()) / 2; + int val = transpositionTable.get(pos.key()); + + if (val != 0) { + max = val + Position.MIN_SCORE - 1; + } + if (beta > max) { beta = max; if (alpha >= beta) { @@ -45,6 +53,7 @@ private int negamax(Position pos, int alpha, int beta) { } } + transpositionTable.put(pos.key(), (byte) (alpha - Position.MIN_SCORE + 1)); return alpha; } @@ -56,6 +65,11 @@ public int solve(Position pos) { public long getNodes() { return nodes; } + + public void reset() { + nodes = 0; + transpositionTable.reset(); + } } class Timer { diff --git a/TranspositionTable.java b/TranspositionTable.java new file mode 100644 index 0000000..a76a25b --- /dev/null +++ b/TranspositionTable.java @@ -0,0 +1,41 @@ +import java.util.Arrays; + +public class TranspositionTable { + private static class Entry { + public long key; + public byte val; + + public Entry(long key, byte val) { + this.key = key; + this.val = val; + } + } + + private Entry[] entries; + + public TranspositionTable(int size) { + entries = new Entry[size]; + Arrays.fill(entries, new Entry(0, (byte) 0)); + } + + public void reset() { + Arrays.fill(entries, new Entry(0, (byte) 0)); + } + + public void put(long key, byte val) { + int i = index(key); + entries[i] = new Entry(key, val); + } + + public byte get(long key) { + int i = index(key); + if (entries[i].key == key) + return entries[i].val; + else + return 0; + } + + private int index(long key) { + return (int) (key % entries.length); + } +}