performance progress tracking
This commit is contained in:
parent
f1ec0a08d9
commit
e66771dba9
16 changed files with 2816 additions and 121 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -1,7 +1,4 @@
|
||||||
/target
|
/target
|
||||||
/.idea
|
/.idea
|
||||||
/src/replace.py
|
|
||||||
/scripts
|
|
||||||
/Cargo.lock
|
/Cargo.lock
|
||||||
/src/benchmark/stockfish
|
progress_tracking/progress.xlsx
|
||||||
uci_log.txt
|
|
||||||
11
Cargo.toml
11
Cargo.toml
|
|
@ -5,3 +5,14 @@ authors = ["Moritz Eigenauer <moritz.eigenauer@gmail.com>"]
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
criterion = "0.7.0"
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "perft"
|
||||||
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "eval"
|
||||||
|
harness = false
|
||||||
15
benches/eval.rs
Normal file
15
benches/eval.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
use chess_engine::board::Board;
|
||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
use chess_engine::eval::basic::evaluate_board;
|
||||||
|
|
||||||
|
fn run_eval_benchmark(c: &mut Criterion) {
|
||||||
|
let board = Board::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1");
|
||||||
|
c.bench_function("standard_board_evaluation", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
evaluate_board(&board);
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, run_eval_benchmark);
|
||||||
|
criterion_main!(benches);
|
||||||
36
benches/perft.rs
Normal file
36
benches/perft.rs
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
use chess_engine::board::Board;
|
||||||
|
use chess_engine::movegen::generate_pseudo_legal_moves;
|
||||||
|
use chess_engine::movegen::legal_check::is_other_king_attacked;
|
||||||
|
use chess_engine::r#move::MoveList;
|
||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
fn count_legal_moves_recursive(board: &mut Board, depth: u8) -> u64 {
|
||||||
|
if depth == 0 {
|
||||||
|
return 1_u64;
|
||||||
|
}
|
||||||
|
let mut list = MoveList::new();
|
||||||
|
generate_pseudo_legal_moves(&board, &mut list);
|
||||||
|
let mut leaf_nodes = 0_u64;
|
||||||
|
for mv in list.iter() {
|
||||||
|
let undo_info = board.make_move(*mv);
|
||||||
|
if !is_other_king_attacked(board) {
|
||||||
|
leaf_nodes += count_legal_moves_recursive(board, depth - 1);
|
||||||
|
}
|
||||||
|
board.undo_move(undo_info);
|
||||||
|
}
|
||||||
|
leaf_nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn run_perft_benchmark(c: &mut Criterion) {
|
||||||
|
let mut board = Board::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1");
|
||||||
|
|
||||||
|
c.bench_function("standard_perft5", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
assert_eq!(count_legal_moves_recursive(&mut board, 5), 4865609);
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, run_perft_benchmark);
|
||||||
|
criterion_main!(benches);
|
||||||
1188
helper_scripts/STS1-STS15_LAN_v6.epd
Normal file
1188
helper_scripts/STS1-STS15_LAN_v6.epd
Normal file
File diff suppressed because it is too large
Load diff
78
helper_scripts/convert_epd.py
Normal file
78
helper_scripts/convert_epd.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
# converts the stockfish test suite epd file to a csv file containing just the fen and the best move
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
INPUT_FILE_PATH = "STS1-STS15_LAN_v6.epd"
|
||||||
|
OUTPUT_FILE_PATH = "../src/bin/stockfish_testsuite.csv"
|
||||||
|
|
||||||
|
def parse_line(line: str) -> str | None:
|
||||||
|
try:
|
||||||
|
parts = line.split(';')
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
|
||||||
|
main_part = parts[0]
|
||||||
|
other_parts = parts[1:]
|
||||||
|
|
||||||
|
bm_index = main_part.find(" bm ")
|
||||||
|
if bm_index == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fen = main_part[:bm_index].strip()
|
||||||
|
print(f"fen: '{fen}'")
|
||||||
|
fen += " 0 1"
|
||||||
|
|
||||||
|
lan_move = None
|
||||||
|
for part in other_parts:
|
||||||
|
part = part.strip()
|
||||||
|
if part.startswith('c9 "'):
|
||||||
|
content_start = len('c9 "')
|
||||||
|
content_end = part.rfind('"')
|
||||||
|
|
||||||
|
if content_end <= content_start:
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = part[content_start:content_end].strip()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
lan_move = content.split()[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
if lan_move is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return f"{fen},{lan_move}"
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def convert_file(input_path: str, output_path: str):
|
||||||
|
try:
|
||||||
|
with open(input_path, 'r', encoding='utf-8') as infile, \
|
||||||
|
open(output_path, 'w', encoding='utf-8') as outfile:
|
||||||
|
|
||||||
|
for line in infile:
|
||||||
|
line_content = line.strip()
|
||||||
|
|
||||||
|
if not line_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
output_line = parse_line(line_content)
|
||||||
|
|
||||||
|
if output_line:
|
||||||
|
outfile.write(output_line + '\n')
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Input file '{input_path}' not found.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
except IOError as e:
|
||||||
|
print(f"Error reading/writing file: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An unexpected error occurred: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
convert_file(INPUT_FILE_PATH, OUTPUT_FILE_PATH)
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
a2a3 380
|
|
||||||
b2b3 420
|
|
||||||
c2c3 420
|
|
||||||
d2d3 539
|
|
||||||
e2e3 599
|
|
||||||
f2f3 380
|
|
||||||
g2g3 420
|
|
||||||
h2h3 380
|
|
||||||
a2a4 420
|
|
||||||
b2b4 421
|
|
||||||
c2c4 441
|
|
||||||
d2d4 560
|
|
||||||
e2e4 600
|
|
||||||
f2f4 401
|
|
||||||
g2g4 421
|
|
||||||
h2h4 420
|
|
||||||
b1a3 400
|
|
||||||
b1c3 440
|
|
||||||
g1f3 440
|
|
||||||
g1h3 400
|
|
||||||
|
|
||||||
Total8902
|
|
||||||
215
progress_tracking/collect_benchmarks.py
Normal file
215
progress_tracking/collect_benchmarks.py
Normal file
|
|
@ -0,0 +1,215 @@
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
import openpyxl
|
||||||
|
import datetime
|
||||||
|
from openpyxl.styles import Font
|
||||||
|
from openpyxl.formatting.rule import ColorScaleRule
|
||||||
|
from openpyxl.utils import get_column_letter
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# Adjust these paths if your benchmark names are different!
|
||||||
|
PERFT_JSON_PATH = "C:/Users/Moritz/RustroverProjects/ChessEngine/target/criterion/standard_perft5/new/estimates.json"
|
||||||
|
EVAL_JSON_PATH = "C:/Users/Moritz/RustroverProjects/ChessEngine/target/criterion/standard_board_evaluation/new/estimates.json"
|
||||||
|
EXCEL_FILE = "C:/Users/Moritz/RustroverProjects/ChessEngine/progress_tracking/progress.xlsx"
|
||||||
|
HEADERS = ["TIMESTAMP", "COMMIT", "MESSAGE", "PERFT (ms)", "EVAL (ps)", "SUITE (%)"]
|
||||||
|
|
||||||
|
COLUMN_WIDTHS = {
|
||||||
|
'A': 20, # Timestamp
|
||||||
|
'B': 12, # Commit
|
||||||
|
'C': 50, # Message
|
||||||
|
'D': 14, # Perft
|
||||||
|
'E': 14, # Eval
|
||||||
|
'F': 14 # Suite
|
||||||
|
}
|
||||||
|
|
||||||
|
# NEW: Define fonts
|
||||||
|
DEFAULT_FONT = Font(name='Consolas', size=11)
|
||||||
|
HEADER_FONT = Font(name='Consolas', size=11, bold=True)
|
||||||
|
# ---------------------
|
||||||
|
|
||||||
|
def run_command(command):
|
||||||
|
"""Executes a shell command and returns its output."""
|
||||||
|
print(f"Running: {' '.join(command)}")
|
||||||
|
try:
|
||||||
|
result = subprocess.run(command, capture_output=True, text=True, check=True, encoding='utf-8')
|
||||||
|
return result
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Error running command: {e}")
|
||||||
|
print("STDOUT:", e.stdout)
|
||||||
|
print("STDERR:", e.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
def get_criterion_result(json_path):
|
||||||
|
"""Reads the result from a Criterion JSON file."""
|
||||||
|
try:
|
||||||
|
with open(json_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
# Returns the 'point_estimate' of the mean in nanoseconds
|
||||||
|
return data['mean']['point_estimate']
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: JSON file not found: {json_path}")
|
||||||
|
print("Make sure 'cargo bench' was successful and the paths are correct.")
|
||||||
|
exit(1)
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
print(f"Error: Unexpected format in {json_path}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
def get_git_info():
|
||||||
|
"""Checks if the git working directory is dirty. Returns (hash, message)"""
|
||||||
|
status_result = run_command(["git", "status", "--porcelain"])
|
||||||
|
|
||||||
|
if status_result.stdout.strip():
|
||||||
|
print("Uncommitted changes detected. Using 'local' as commit ID.")
|
||||||
|
return ("local", "Uncommitted changes")
|
||||||
|
else:
|
||||||
|
hash_result = run_command(["git", "rev-parse", "--short", "HEAD"])
|
||||||
|
msg_result = run_command(["git", "log", "-1", "--pretty=%s"])
|
||||||
|
return (hash_result.stdout.strip(), msg_result.stdout.strip())
|
||||||
|
|
||||||
|
def apply_styles_and_formats(ws, row_index, is_header=False):
|
||||||
|
"""Applies fonts and number formats to a specific row."""
|
||||||
|
font = HEADER_FONT if is_header else DEFAULT_FONT
|
||||||
|
|
||||||
|
# Get column indices
|
||||||
|
try:
|
||||||
|
perft_col_idx = HEADERS.index('PERFT (ms)') + 1
|
||||||
|
eval_col_idx = HEADERS.index('EVAL (ps)') + 1
|
||||||
|
suite_col_idx = HEADERS.index('SUITE (%)') + 1
|
||||||
|
except ValueError:
|
||||||
|
print("Error: Could not find all headers. Check HEADERS config.")
|
||||||
|
return
|
||||||
|
|
||||||
|
for cell in ws[row_index]:
|
||||||
|
cell.font = font
|
||||||
|
|
||||||
|
# Apply number formats only to data rows
|
||||||
|
if not is_header:
|
||||||
|
if cell.column == perft_col_idx or cell.column == eval_col_idx or cell.column == suite_col_idx:
|
||||||
|
cell.number_format = '0.00'
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. Run benchmarks and suite
|
||||||
|
print("Starting benchmarks... (This may take a few minutes)")
|
||||||
|
run_command(["cargo", "bench", "--bench", "perft"])
|
||||||
|
run_command(["cargo", "bench", "--bench", "eval"])
|
||||||
|
|
||||||
|
print("Starting suite test...")
|
||||||
|
suite_result = run_command(["cargo", "run", "--bin", "suite", "--release"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# The suite_score is still a raw float, e.g., 95.5
|
||||||
|
suite_score = float(suite_result.stdout.strip())
|
||||||
|
except ValueError:
|
||||||
|
print(f"Error: Could not convert suite output to a number.")
|
||||||
|
print(f"Received: '{suite_result.stdout}'")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("Collecting results...")
|
||||||
|
|
||||||
|
# 2. Get Git info and Timestamp
|
||||||
|
(commit_hash, commit_message) = get_git_info()
|
||||||
|
timestamp = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
|
||||||
|
|
||||||
|
# 3. Read benchmark results
|
||||||
|
# Convert from nanoseconds to milliseconds
|
||||||
|
perft_ms = get_criterion_result(PERFT_JSON_PATH) / 1_000_000.0
|
||||||
|
# Convert from nanoseconds to picoseconds
|
||||||
|
eval_ps = get_criterion_result(EVAL_JSON_PATH) * 1000.0
|
||||||
|
|
||||||
|
# 4. Write data to the Excel file
|
||||||
|
file_path = pathlib.Path(EXCEL_FILE)
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
wb = openpyxl.load_workbook(EXCEL_FILE)
|
||||||
|
ws = wb.active
|
||||||
|
# Check if cell A1 has the correct header. If not, the file is empty/corrupt
|
||||||
|
if ws.cell(row=1, column=1).value != HEADERS[0]:
|
||||||
|
print("File was empty or corrupt. Re-creating headers.")
|
||||||
|
ws.append(HEADERS)
|
||||||
|
apply_styles_and_formats(ws, 1, is_header=True)
|
||||||
|
else:
|
||||||
|
wb = openpyxl.Workbook()
|
||||||
|
ws = wb.active
|
||||||
|
ws.title = "Progress"
|
||||||
|
ws.append(HEADERS)
|
||||||
|
apply_styles_and_formats(ws, 1, is_header=True) # Apply header style
|
||||||
|
print(f"New file '{EXCEL_FILE}' created.")
|
||||||
|
|
||||||
|
# --- Set Column Widths ---
|
||||||
|
# !! This was the fix: Removed the "if" check and adjusted units.
|
||||||
|
for col_letter, width in COLUMN_WIDTHS.items():
|
||||||
|
ws.column_dimensions[col_letter].width = width
|
||||||
|
|
||||||
|
# --- Overwrite Logic ---
|
||||||
|
if commit_hash == "local" and ws.max_row > 1:
|
||||||
|
try:
|
||||||
|
commit_col_index = HEADERS.index("COMMIT") + 1
|
||||||
|
except ValueError:
|
||||||
|
print("Error: 'COMMIT' column not found in headers.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
last_row_commit_val = ws.cell(row=ws.max_row, column=commit_col_index).value
|
||||||
|
|
||||||
|
if last_row_commit_val == "local":
|
||||||
|
ws.delete_rows(ws.max_row)
|
||||||
|
print("Overwriting previous 'local' entry.")
|
||||||
|
|
||||||
|
# Append the new row of data (using ms values)
|
||||||
|
new_row = [timestamp, commit_hash, commit_message, perft_ms, eval_ps, suite_score]
|
||||||
|
ws.append(new_row)
|
||||||
|
|
||||||
|
# Apply default font and number formats to the newly added row
|
||||||
|
apply_styles_and_formats(ws, ws.max_row, is_header=False)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Add/Update Conditional Formatting ---
|
||||||
|
perf_rule = ColorScaleRule(
|
||||||
|
start_type='min', start_color='63BE7B', # Green (Low = Fast = Good)
|
||||||
|
mid_type='percentile', mid_value=50, mid_color='FFEB84',
|
||||||
|
end_type='max', end_color='F8696B' # Red (High = Slow = Bad)
|
||||||
|
)
|
||||||
|
suite_rule = ColorScaleRule(
|
||||||
|
start_type='min', start_color='F8696B', # Red (Low = Bad)
|
||||||
|
mid_type='percentile', mid_value=50, mid_color='FFEB84',
|
||||||
|
end_type='max', end_color='63BE7B' # Green (High = Good)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
perft_col_letter = get_column_letter(HEADERS.index('PERFT (ms)') + 1)
|
||||||
|
# Note: This had a typo in your original file 'EVAL (fs)', I assume you meant 'EVAL (ps)'
|
||||||
|
eval_col_letter = get_column_letter(HEADERS.index('EVAL (ps)') + 1)
|
||||||
|
suite_col_letter = get_column_letter(HEADERS.index('SUITE (%)') + 1)
|
||||||
|
|
||||||
|
max_excel_row = 1048576 # Standard for .xlsx
|
||||||
|
ws.conditional_formatting.add(f'{perft_col_letter}2:{perft_col_letter}{max_excel_row}', perf_rule)
|
||||||
|
ws.conditional_formatting.add(f'{eval_col_letter}2:{eval_col_letter}{max_excel_row}', perf_rule)
|
||||||
|
ws.conditional_formatting.add(f'{suite_col_letter}2:{suite_col_letter}{max_excel_row}', suite_rule)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
print("Warning: Could not find performance columns in headers. Skipping color formatting.")
|
||||||
|
# Print which headers are problematic
|
||||||
|
for col in ['PERFT (ms)', 'EVAL (ps)', 'SUITE (%)']:
|
||||||
|
if col not in HEADERS:
|
||||||
|
print(f"Header '{col}' is missing or misspelled in HEADERS list.")
|
||||||
|
|
||||||
|
|
||||||
|
# 5. Save the file
|
||||||
|
try:
|
||||||
|
wb.save(EXCEL_FILE)
|
||||||
|
except PermissionError:
|
||||||
|
print(f"Error: Could not save '{EXCEL_FILE}'.")
|
||||||
|
print("Please make sure the file is not open in Excel.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print("-" * 30)
|
||||||
|
print(f"Success! Results saved to '{EXCEL_FILE}'.")
|
||||||
|
print(f" TIMESTAMP: {timestamp}")
|
||||||
|
print(f" COMMIT: {commit_hash}")
|
||||||
|
print(f" MESSAGE: {commit_message}")
|
||||||
|
print(f" PERFT: {perft_ms:.2f} ms")
|
||||||
|
print(f" EVAL: {eval_ps:.2f} ps")
|
||||||
|
print(f" SUITE: {suite_score:.2f} %")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1188
src/bin/stockfish_testsuite.csv
Normal file
1188
src/bin/stockfish_testsuite.csv
Normal file
File diff suppressed because it is too large
Load diff
63
src/bin/suite.rs
Normal file
63
src/bin/suite.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{self, BufRead};
|
||||||
|
use chess_engine::engine::Engine;
|
||||||
|
use std::time::{Instant, Duration};
|
||||||
|
// EACH TEST CAN ONLY TAKE ONE SECOND MAX TO KEEP RESULTS COMPARABLE
|
||||||
|
|
||||||
|
fn load_csv(path: &str) -> io::Result<Vec<Vec<String>>> {
|
||||||
|
let file = File::open(path)?;
|
||||||
|
let reader = io::BufReader::new(file);
|
||||||
|
|
||||||
|
let mut rows = Vec::new();
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
let cols = line
|
||||||
|
.split(',')
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
|
||||||
|
rows.push(cols);
|
||||||
|
}
|
||||||
|
Ok(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let mut total_tests: f32 = 0.0;
|
||||||
|
let mut correct_tests: f32 = 0.0;
|
||||||
|
let sts = load_csv("C:/Users/Moritz/RustroverProjects/ChessEngine/src/bin/stockfish_testsuite.csv").unwrap();
|
||||||
|
let mut engine = Engine::new("Yakari".to_string(), "EiSiMo".to_string());
|
||||||
|
|
||||||
|
// Set the time limit to 1 second
|
||||||
|
let time_limit = Duration::from_secs(1);
|
||||||
|
|
||||||
|
for test in &sts {
|
||||||
|
let fen = &test[0];
|
||||||
|
let bm = &test[1];
|
||||||
|
|
||||||
|
engine.setpos_fen(fen);
|
||||||
|
|
||||||
|
// Record start time
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let result = engine.search(4);
|
||||||
|
|
||||||
|
// Calculate duration
|
||||||
|
let duration = start_time.elapsed();
|
||||||
|
|
||||||
|
// Check if the test exceeded the time limit
|
||||||
|
if duration > time_limit {
|
||||||
|
panic!(
|
||||||
|
"Test exceeded 1 second limit: {:?} for FEN: {}",
|
||||||
|
duration, fen
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
total_tests += 1.0;
|
||||||
|
if result == *bm {
|
||||||
|
correct_tests += 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{}", correct_tests / (total_tests / 100.0));
|
||||||
|
}
|
||||||
|
|
@ -32,14 +32,14 @@ impl Engine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn search(&mut self, depth: u8) {
|
pub fn search(&mut self, depth: u8) -> String {
|
||||||
let (opt_move, _score) = minimax(&mut self.board, depth);
|
let (opt_move, _score) = minimax(&mut self.board, depth, 0);
|
||||||
|
|
||||||
if let Some(mv) = opt_move {
|
if let Some(mv) = opt_move {
|
||||||
println!("bestmove {}", mv);
|
mv.to_algebraic()
|
||||||
} else {
|
} else {
|
||||||
// UCI format for no legal moves (checkmate/stalemate)
|
// UCI format for no legal moves (checkmate/stalemate)
|
||||||
println!("bestmove null");
|
"null".to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -7,12 +7,12 @@ pub fn evaluate_board(board: &Board) -> i32 {
|
||||||
score += board.pieces[PieceType::Bishop as usize][Color::White as usize].count_ones() as i32 * 300;
|
score += board.pieces[PieceType::Bishop as usize][Color::White as usize].count_ones() as i32 * 300;
|
||||||
score += board.pieces[PieceType::Rook as usize][Color::White as usize].count_ones() as i32 * 500;
|
score += board.pieces[PieceType::Rook as usize][Color::White as usize].count_ones() as i32 * 500;
|
||||||
score += board.pieces[PieceType::Queen as usize][Color::White as usize].count_ones() as i32 * 900;
|
score += board.pieces[PieceType::Queen as usize][Color::White as usize].count_ones() as i32 * 900;
|
||||||
score += board.pieces[PieceType::King as usize][Color::White as usize].count_ones() as i32 * 10000;
|
|
||||||
score -= board.pieces[PieceType::Pawn as usize][Color::Black as usize].count_ones() as i32 * 100;
|
score -= board.pieces[PieceType::Pawn as usize][Color::Black as usize].count_ones() as i32 * 100;
|
||||||
score -= board.pieces[PieceType::Knight as usize][Color::Black as usize].count_ones() as i32 * 300;
|
score -= board.pieces[PieceType::Knight as usize][Color::Black as usize].count_ones() as i32 * 300;
|
||||||
score -= board.pieces[PieceType::Bishop as usize][Color::Black as usize].count_ones() as i32 * 300;
|
score -= board.pieces[PieceType::Bishop as usize][Color::Black as usize].count_ones() as i32 * 300;
|
||||||
score -= board.pieces[PieceType::Rook as usize][Color::Black as usize].count_ones() as i32 * 500;
|
score -= board.pieces[PieceType::Rook as usize][Color::Black as usize].count_ones() as i32 * 500;
|
||||||
score -= board.pieces[PieceType::Queen as usize][Color::Black as usize].count_ones() as i32 * 900;
|
score -= board.pieces[PieceType::Queen as usize][Color::Black as usize].count_ones() as i32 * 900;
|
||||||
score -= board.pieces[PieceType::King as usize][Color::Black as usize].count_ones() as i32 * 10000;
|
|
||||||
score
|
score
|
||||||
}
|
}
|
||||||
|
|
@ -4,6 +4,8 @@ use crate::movegen::generate_pseudo_legal_moves;
|
||||||
use crate::movegen::legal_check::is_other_king_attacked;
|
use crate::movegen::legal_check::is_other_king_attacked;
|
||||||
use crate::r#move::{Move, MoveList};
|
use crate::r#move::{Move, MoveList};
|
||||||
|
|
||||||
|
// A score high enough to be > any material eval, but low enough to not overflow when adding ply
|
||||||
|
const MATE_SCORE: i32 = 1_000_000;
|
||||||
|
|
||||||
fn evaluate_board_relative(board: &Board) -> i32 {
|
fn evaluate_board_relative(board: &Board) -> i32 {
|
||||||
let static_eval = evaluate_board(board);
|
let static_eval = evaluate_board(board);
|
||||||
|
|
@ -13,7 +15,7 @@ fn evaluate_board_relative(board: &Board) -> i32 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn minimax(board: &mut Board, depth: u8) -> (Option<Move>, i32) {
|
pub fn minimax(board: &mut Board, depth: u8, ply: u8) -> (Option<Move>, i32) {
|
||||||
if depth == 0 {
|
if depth == 0 {
|
||||||
return (None, evaluate_board_relative(board));
|
return (None, evaluate_board_relative(board));
|
||||||
}
|
}
|
||||||
|
|
@ -21,7 +23,7 @@ pub fn minimax(board: &mut Board, depth: u8) -> (Option<Move>, i32) {
|
||||||
let mut list = MoveList::new();
|
let mut list = MoveList::new();
|
||||||
generate_pseudo_legal_moves(board, &mut list);
|
generate_pseudo_legal_moves(board, &mut list);
|
||||||
let mut best_move: Option<Move> = None;
|
let mut best_move: Option<Move> = None;
|
||||||
let mut best_score: i32 = -i32::MAX;
|
let mut best_score: i32 = -i32::MAX; // Start with the worst possible score
|
||||||
let mut legal_moves_found = false;
|
let mut legal_moves_found = false;
|
||||||
|
|
||||||
for mv in list.iter() {
|
for mv in list.iter() {
|
||||||
|
|
@ -32,7 +34,9 @@ pub fn minimax(board: &mut Board, depth: u8) -> (Option<Move>, i32) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
legal_moves_found = true;
|
legal_moves_found = true;
|
||||||
let (_, score) = minimax(board, depth - 1);
|
|
||||||
|
// Recursive call, incrementing ply
|
||||||
|
let (_, score) = minimax(board, depth - 1, ply + 1);
|
||||||
let current_score = -score;
|
let current_score = -score;
|
||||||
|
|
||||||
if current_score > best_score {
|
if current_score > best_score {
|
||||||
|
|
@ -45,8 +49,12 @@ pub fn minimax(board: &mut Board, depth: u8) -> (Option<Move>, i32) {
|
||||||
|
|
||||||
if !legal_moves_found {
|
if !legal_moves_found {
|
||||||
if is_other_king_attacked(board) {
|
if is_other_king_attacked(board) {
|
||||||
return (None, -i32::MAX);
|
// Checkmate
|
||||||
|
// The score is *less* negative the *longer* it takes to be mated (higher ply)
|
||||||
|
// This translates to a *higher* score for the winner for a *faster* mate
|
||||||
|
return (None, -MATE_SCORE + (ply as i32));
|
||||||
} else {
|
} else {
|
||||||
|
// Stalemate
|
||||||
return (None, 0);
|
return (None, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,14 +32,9 @@ pub fn uci_mainloop(engine: &mut Engine) {
|
||||||
let fen = tokens[2..].join(" ");
|
let fen = tokens[2..].join(" ");
|
||||||
engine.setpos_fen(&fen);
|
engine.setpos_fen(&fen);
|
||||||
} else if tokens[1] == "startpos" {
|
} else if tokens[1] == "startpos" {
|
||||||
// Check explicitly for the "moves" keyword
|
|
||||||
if tokens.len() > 2 && tokens[2] == "moves" {
|
if tokens.len() > 2 && tokens[2] == "moves" {
|
||||||
// Command: "position startpos moves e2e4 e7e5 ..."
|
|
||||||
// Pass only the tokens *after* "moves"
|
|
||||||
engine.setpos_startpos(&tokens[3..]);
|
engine.setpos_startpos(&tokens[3..]);
|
||||||
} else {
|
} else {
|
||||||
// Command: "position startpos"
|
|
||||||
// Pass an empty slice
|
|
||||||
engine.setpos_startpos(&[]);
|
engine.setpos_startpos(&[]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -47,7 +42,7 @@ pub fn uci_mainloop(engine: &mut Engine) {
|
||||||
}
|
}
|
||||||
"go" => {
|
"go" => {
|
||||||
// TODO add a lot functionality
|
// TODO add a lot functionality
|
||||||
engine.search(5);
|
println!("{}", engine.search(6));
|
||||||
}
|
}
|
||||||
"stop" => {
|
"stop" => {
|
||||||
// TODO stop search as soon as possible
|
// TODO stop search as soon as possible
|
||||||
|
|
|
||||||
|
|
@ -1,76 +0,0 @@
|
||||||
use chess_engine::board::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_standard() {
|
|
||||||
let fen_standard = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_standard).to_fen(), fen_standard);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_kiwipete() {
|
|
||||||
let fen_kiwipete = "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_kiwipete).to_fen(), fen_kiwipete);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_en_passant() {
|
|
||||||
let fen_en_passant = "rnbqkbnr/pppppp1p/8/8/p7/4P3/PPPP1PPP/RNBQKBNR w KQkq e3 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_en_passant).to_fen(), fen_en_passant);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_castle() {
|
|
||||||
let fen_castle = "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2QK2R b - - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_castle).to_fen(), fen_castle);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_just_kings() {
|
|
||||||
let fen_just_kings = "8/k7/8/8/8/8/7K/8 w - - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_just_kings).to_fen(), fen_just_kings);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_high_move_values() {
|
|
||||||
let fen_high_move_values = "8/P1k5/K7/8/8/8/8/8 w - - 0 78";
|
|
||||||
assert_eq!(Board::from_fen(fen_high_move_values).to_fen(), fen_high_move_values);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_empty_count1() {
|
|
||||||
let fen_empty_count1 = "1n6/8/8/8/8/8/8/8 w - - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_empty_count1).to_fen(), fen_empty_count1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_fen_roundtrip_empty_count2() {
|
|
||||||
let fen_empty_count2 = "6n1/8/8/8/8/8/8/8 w - - 0 1";
|
|
||||||
assert_eq!(Board::from_fen(fen_empty_count2).to_fen(), fen_empty_count2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_board_fen_state() {
|
|
||||||
let fen_standard = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
|
|
||||||
let board = Board::from_fen(fen_standard);
|
|
||||||
assert_eq!(board.pieces[PieceType::Pawn as usize][Color::White as usize], 65280);
|
|
||||||
assert_eq!(board.pieces[PieceType::Pawn as usize][Color::Black as usize], 71776119061217280);
|
|
||||||
assert_eq!(board.pieces[PieceType::Knight as usize][Color::White as usize], 66);
|
|
||||||
assert_eq!(board.pieces[PieceType::Knight as usize][Color::Black as usize], 4755801206503243776);
|
|
||||||
assert_eq!(board.pieces[PieceType::Bishop as usize][Color::White as usize], 36);
|
|
||||||
assert_eq!(board.pieces[PieceType::Bishop as usize][Color::Black as usize], 2594073385365405696);
|
|
||||||
assert_eq!(board.pieces[PieceType::Rook as usize][Color::White as usize], 129);
|
|
||||||
assert_eq!(board.pieces[PieceType::Rook as usize][Color::Black as usize], 9295429630892703744);
|
|
||||||
assert_eq!(board.pieces[PieceType::Queen as usize][Color::White as usize], 8);
|
|
||||||
assert_eq!(board.pieces[PieceType::Queen as usize][Color::Black as usize], 576460752303423488);
|
|
||||||
assert_eq!(board.pieces[PieceType::King as usize][Color::White as usize], 16);
|
|
||||||
assert_eq!(board.pieces[PieceType::King as usize][Color::Black as usize], 1152921504606846976);
|
|
||||||
|
|
||||||
assert_eq!(board.occupied[0], 65535);
|
|
||||||
assert_eq!(board.occupied[1], 18446462598732840960);
|
|
||||||
assert_eq!(board.all_occupied, 18446462598732906495);
|
|
||||||
|
|
||||||
assert_eq!(board.castling_rights, 15);
|
|
||||||
assert_eq!(board.en_passant_target, None);
|
|
||||||
assert_eq!(board.halfmove_clock, 0);
|
|
||||||
assert_eq!(board.fullmove_number, 1);
|
|
||||||
}
|
|
||||||
|
|
@ -27,8 +27,7 @@ fn count_legal_moves_recursive(board: &mut Board, depth: u8) -> u64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn perft() {
|
fn test_perft() {
|
||||||
// TalkChess PERFT Tests (by Martin Sedlak)
|
|
||||||
// Illegal ep move #1
|
// Illegal ep move #1
|
||||||
let mut board = Board::from_fen("3k4/3p4/8/K1P4r/8/8/8/8 b - - 0 1");
|
let mut board = Board::from_fen("3k4/3p4/8/K1P4r/8/8/8/8 b - - 0 1");
|
||||||
assert_eq!(count_legal_moves_recursive(&mut board, 6), 1134888, "Illegal ep move #1");
|
assert_eq!(count_legal_moves_recursive(&mut board, 6), 1134888, "Illegal ep move #1");
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue