diff --git a/VSharp.Explorer/AISearcher.fs b/VSharp.Explorer/AISearcher.fs index 639f2f651..49a450b38 100644 --- a/VSharp.Explorer/AISearcher.fs +++ b/VSharp.Explorer/AISearcher.fs @@ -123,32 +123,28 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option Runner && stepsToPlay = stepsPlayed then + let toPredict = + match aiMode with + | TrainingSendEachStep + | TrainingSendModel -> + if stepsPlayed > 0u then + gameStateDelta + else + gameState.Value + | Runner -> gameState.Value + + let stateId = oracle.Predict toPredict + afterFirstAIPeek <- true + let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId) + lastCollectedStatistics <- statistics + stepsPlayed <- stepsPlayed + 1u + + match state with + | Some state -> Some state + | None -> + incorrectPredictedStateId <- true + oracle.Feedback(Feedback.IncorrectPredictedStateId stateId) None - else - let toPredict = - match aiMode with - | TrainingSendEachStep - | TrainingSendModel -> - if stepsPlayed > 0u then - gameStateDelta - else - gameState.Value - | Runner -> gameState.Value - - let stateId = oracle.Predict toPredict - - afterFirstAIPeek <- true - let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId) - lastCollectedStatistics <- statistics - stepsPlayed <- stepsPlayed + 1u - - match state with - | Some state -> Some state - | None -> - incorrectPredictedStateId <- true - oracle.Feedback(Feedback.IncorrectPredictedStateId stateId) - None static member updateGameState (delta: GameState) (gameState: Option) = match gameState with diff --git a/VSharp.ML.GameServer.Runner/Main.fs b/VSharp.ML.GameServer.Runner/Main.fs index 1b4b9a207..9f409b36f 100644 --- a/VSharp.ML.GameServer.Runner/Main.fs +++ b/VSharp.ML.GameServer.Runner/Main.fs @@ -239,6 +239,7 @@ let ws port outputDirectory (webSocket: WebSocket) (context: HttpContext) = GameOver( explorationResult.ActualCoverage, explorationResult.TestsCount, + explorationResult.StepsCount, explorationResult.ErrorsCount ) ) diff --git a/VSharp.ML.GameServer/Messages.fs b/VSharp.ML.GameServer/Messages.fs index 43c60c51b..7f13570f7 100644 --- a/VSharp.ML.GameServer/Messages.fs +++ b/VSharp.ML.GameServer/Messages.fs @@ -109,11 +109,13 @@ type GameOverMessageBody = interface IRawOutgoingMessageBody val ActualCoverage: uint val TestsCount: uint32 + val StepsCount: uint32 val ErrorsCount: uint32 - new(actualCoverage, testsCount, errorsCount) = + new(actualCoverage, testsCount, stepsCount, errorsCount) = { ActualCoverage = actualCoverage TestsCount = testsCount + StepsCount = stepsCount ErrorsCount = errorsCount } [] @@ -371,7 +373,7 @@ type IncorrectPredictedStateIdMessageBody = new(stateId) = { StateId = stateId } type OutgoingMessage = - | GameOver of uint * uint32 * uint32 + | GameOver of uint * uint32 * uint32 * uint32 | MoveReward of Reward | IncorrectPredictedStateId of uint | ReadyForNextStep of GameState @@ -397,8 +399,8 @@ let deserializeInputMessage (messageData: byte[]) = let serializeOutgoingMessage (message: OutgoingMessage) = match message with - | GameOver(actualCoverage, testsCount, errorsCount) -> - RawOutgoingMessage("GameOver", box (GameOverMessageBody(actualCoverage, testsCount, errorsCount))) + | GameOver(actualCoverage, testsCount, stepsCount, errorsCount) -> + RawOutgoingMessage("GameOver", box (GameOverMessageBody(actualCoverage, testsCount, stepsCount, errorsCount))) | MoveReward reward -> RawOutgoingMessage("MoveReward", reward) | IncorrectPredictedStateId stateId -> RawOutgoingMessage("IncorrectPredictedStateId", IncorrectPredictedStateIdMessageBody stateId)