diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 776ed6f..4cccac7 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Any, Literal, Self import numpy as np @@ -230,3 +230,11 @@ class MultiAgentPolicyManager(BasePolicy): for k, v in out.items(): results[agent_id + "/" + k] = v return results + + # Need a train method that set all sub-policies to train mode. + # No need for a similar eval function, as eval internally uses the train function. + def train(self, mode: bool = True) -> Self: + """Set each internal policy in training mode.""" + for policy in self.policies.values(): + policy.train(mode) + return self