diff --git a/lc0/src/chess/board.cc b/lc0/src/chess/board.cc index 60a1a0783..ad0cbc71b 100644 --- a/lc0/src/chess/board.cc +++ b/lc0/src/chess/board.cc @@ -690,14 +690,45 @@ void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply, throw Exception("Bad fen string: " + fen + " wrong en passant rank"); pawns_.set((square.row() == 2) ? 0 : 7, square.col()); } - if (who_to_move == "b" || who_to_move == "B") { Mirror(); } + + // if the fen states that we are on turn 1 and we are not in the default starting position, + // then the fen most likely specify the wrong move count (or it could be after just 1 ply) + // we modify the turn count so we later know we can add fake history to the net + if(total_moves < 2 && kStartingFen != fen) + total_moves = 5; + if (no_capture_ply) *no_capture_ply = no_capture_halfmoves; if (moves) *moves = total_moves; } +void ChessBoard::UndoEnPassantFromFen() const{ + for(int col = 0; col < 8; col++){ + // the opponent did a pawn move in turn before fen + if(pawns_.get(7,col)){ + // clear en passant flag + pawns_.reset(7,col); + //move piece back + pawns_.reset(4,col); + their_pieces_.reset(4,col); + pawns_.set(6,col); + their_pieces_.set(6,col); + } + // we did a pawn move in turn before fen + if(pawns_.get(0,col)){ + // clear en passant flag + pawns_.reset(0,col); + //move piece back + pawns_.reset(3,col); + our_pieces_.reset(3,col); + pawns_.set(1,col); + our_pieces_.set(1,col); + } + } +} + bool ChessBoard::HasMatingMaterial() const { if (!rooks_.empty() || !pawns_.empty()) { return true; diff --git a/lc0/src/chess/board.h b/lc0/src/chess/board.h index 3a37b4ab2..6f33ddef9 100644 --- a/lc0/src/chess/board.h +++ b/lc0/src/chess/board.h @@ -38,6 +38,8 @@ class ChessBoard { // the game. void SetFromFen(const std::string& fen, int* no_capture_ply = nullptr, int* moves = nullptr); + // if the game is started from FEN with en passant into, undo the move for fake history + void UndoEnPassantFromFen() const; // Nullifies the whole structure. void Clear(); // Swaps black and white pieces and mirrors them relative to the @@ -143,9 +145,9 @@ class ChessBoard { private: // All white pieces. - BitBoard our_pieces_; + mutable BitBoard our_pieces_; // All black pieces. - BitBoard their_pieces_; + mutable BitBoard their_pieces_; // Rooks and queens. BitBoard rooks_; // Bishops and queens; @@ -155,7 +157,7 @@ class ChessBoard { // corresponding white pawn on rank 4 can be taken en passant. Rank 8 is the // same for black pawns. Those "fake" pawns are not present in white_ and // black_ bitboards. - BitBoard pawns_; + mutable BitBoard pawns_; BoardSquare our_king_; BoardSquare their_king_; Castlings castlings_; @@ -169,4 +171,4 @@ struct MoveExecution { bool reset_50_moves; }; -} // namespace lczero \ No newline at end of file +} // namespace lczero diff --git a/lc0/src/neural/encoder.cc b/lc0/src/neural/encoder.cc index 5b81beae6..021c2de0d 100644 --- a/lc0/src/neural/encoder.cc +++ b/lc0/src/neural/encoder.cc @@ -47,13 +47,29 @@ InputPlanes EncodePositionForNN(const PositionHistory& history, bool flip = false; int history_idx = history.GetLength() - 1; + bool fake_history_plane = false; for (int i = 0; i < std::min(history_planes, kMoveHistory); ++i, flip = !flip, --history_idx) { - if (history_idx < 0) break; + if (history_idx < 0) { + // fill up the history window with the first history + // these planes was never inserted flipped into history.. + // so keep them the same as last real plane by inverting flip again + history_idx=0; + flip = !flip; + fake_history_plane = true; + } + const Position& position = history.GetPositionAt(history_idx); const ChessBoard& board = flip ? position.GetThemBoard() : position.GetBoard(); + + //if(fake_history_plane){ + // if the starting fen contained en passant info, then we know the previous move + // create a board with the pawn moved back + //board.UndoEnPassantFromFen(); + //} + const int base = i * kPlanesPerBoard; result[base + 0].mask = (board.ours() * board.pawns()).as_int(); result[base + 1].mask = (board.our_knights()).as_int(); @@ -70,10 +86,10 @@ InputPlanes EncodePositionForNN(const PositionHistory& history, result[base + 11].mask = (board.their_king()).as_int(); const int repetitions = position.GetRepetitions(); - if (repetitions >= 1) result[base + 12].SetAll(); + if (!fake_history_plane && repetitions >= 1) result[base + 12].SetAll(); } return result; } -} // namespace lczero \ No newline at end of file +} // namespace lczero