diff --git a/README.md b/README.md
index f6d79c5..41935f1 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@
 - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
 - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
+- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index aa24897..aa17040 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -99,6 +99,11 @@ Imitation
     :undoc-members:
     :show-inheritance:
 
+.. autoclass:: tianshou.policy.DiscreteCQLPolicy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Model-based
 -----------
diff --git a/docs/index.rst b/docs/index.rst
index 08ed324..72f479d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -25,6 +25,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
+* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
diff --git a/examples/atari/README.md b/examples/atari/README.md
index 24840f2..eb49a02 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -67,4 +67,35 @@ We test our BCQ implementation on two example tasks (different from author's ver
 | Task | Online DQN | Behavioral | BCQ |
 | ---------------------- | ---------- | ---------- | --------------------------------- |
 | PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
-| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
\ No newline at end of file
+| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
+
+# CQL
+
+To run the CQL algorithm on Atari, you need to do the following things:
+
+- Train an expert by using the command listed in the QRDQN section above;
+- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that the 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
+- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
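
Before training, it can help to sanity-check the generated buffer. A minimal sketch (assuming the `expert.hdf5` file produced by the command above is in the working directory):

```python
from tianshou.data import VectorReplayBuffer

# Load the noisy expert buffer written by `atari_qrdqn.py --watch --save-buffer-name expert.hdf5`.
buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")
print(len(buffer))       # number of stored transitions
print(buffer.obs.shape)  # frame-stacked Atari observations, e.g. (1000000, 4, 84, 84)
```

This is the same file that `atari_cql.py --load-buffer-name expert.hdf5` reads back before handing the data to `offline_trainer`.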
+
We test our CQL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
+
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+
We reduce the size of the offline data to 10% and 1% of the above and get:
+
Buffer size 100000:
+
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
+
Buffer size 10000:
+
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py
index f16db46..142edd8 100644
--- a/examples/atari/atari_bcq.py
+++ b/examples/atari/atari_bcq.py
@@ -135,6 +135,8 @@ def test_discrete_bcq(args=get_args()):
         result = test_collector.collect(n_episode=args.test_num, render=args.render)
         pprint.pprint(result)
+        rew = result["rews"].mean()
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
 
     if args.watch:
         watch()
diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py
new file mode 100644
index 0000000..ff86506
--- /dev/null
+++ b/examples/atari/atari_cql.py
@@ -0,0 +1,147 @@
+import os
+import torch
+import pickle
+import pprint
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.trainer import offline_trainer
+from tianshou.policy import DiscreteCQLPolicy
+from tianshou.data import Collector, VectorReplayBuffer
+
+from atari_network import QRDQN
+from atari_wrapper import wrap_deepmind
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
+    parser.add_argument("--seed", type=int, default=1626)
+    parser.add_argument("--eps-test", type=float, default=0.001)
+    parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--min-q-weight", type=float, default=10.) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--watch", default=False, action="store_true", + help="watch the play of pre-trained policy only") + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_PongNoFrameskip-v4.hdf5") + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu") + args = parser.parse_known_args()[0] + return args + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_discrete_cql(args=get_args()): + # envs + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + net = QRDQN(*args.state_shape, args.action_shape, + args.num_quantiles, args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = DiscreteCQLPolicy( + net, optim, args.gamma, args.num_quantiles, args.n_step, + args.target_update_freq, min_q_weight=args.min_q_weight + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run atari_qrdqn.py first to get expert's data buffer." 
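+    # The buffer may be either a pickled ReplayBuffer (.pkl, as used by the
+    # CartPole unit test) or an HDF5 file (.hdf5, required for the large Atari
+    # buffers that are too big to pickle); any other extension is rejected below.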
+ if args.load_buffer_name.endswith('.pkl'): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + elif args.load_buffer_name.endswith('.hdf5'): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + print(f"Unknown buffer format: {args.load_buffer_name}") + exit(0) + + # collector + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + log_path = os.path.join( + args.logdir, args.task, 'cql', + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, update_interval=args.log_interval) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return False + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + pprint.pprint(result) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + result = offline_trainer( + policy, buffer, test_collector, args.epoch, + args.update_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_discrete_cql(get_args()) diff --git a/test/base/test_env.py b/test/base/test_env.py index ef9474e..cc1dc84 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -118,7 +118,7 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): pass_check = 0 break total_pass += pass_check - if sys.platform != "darwin": # macOS cannot pass this check + if sys.platform == "linux": # Windows/macOS may not pass this check assert total_pass >= 2 diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 2847ac2..27c6d65 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,7 @@ import os import gym import torch +import pickle import pprint import argparse import numpy as np @@ -41,6 +42,9 @@ def get_args(): action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--save-buffer-name', type=str, + default="./expert_QRDQN_CartPole-v0.pkl") parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -130,6 +134,14 @@ def test_qrdqn(args=get_args()): rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + # save buffer in pickle format, for imitation learning unittest + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) + policy.set_eps(0.9) # 10% of expert data as demonstrated in the original paper + collector = Collector(policy, test_envs, buf, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + pickle.dump(buf, open(args.save_buffer_name, "wb")) + print(result["rews"].mean()) + def test_pqrdqn(args=get_args()): args.prioritized_replay = True diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py new file mode 100644 index 0000000..eb7de42 --- /dev/null +++ b/test/discrete/test_qrdqn_il_cql.py @@ -0,0 +1,110 @@ +import os +import gym +import torch +import pickle +import pprint 
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector
+from tianshou.utils import BasicLogger
+from tianshou.env import DummyVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.trainer import offline_trainer
+from tianshou.policy import DiscreteCQLPolicy
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--seed", type=int, default=1626)
+    parser.add_argument("--eps-test", type=float, default=0.001)
+    parser.add_argument("--lr", type=float, default=7e-4)
+    parser.add_argument("--gamma", type=float, default=0.99)
+    parser.add_argument('--num-quantiles', type=int, default=200)
+    parser.add_argument("--n-step", type=int, default=3)
+    parser.add_argument("--target-update-freq", type=int, default=320)
+    parser.add_argument("--min-q-weight", type=float, default=10.)
+    parser.add_argument("--epoch", type=int, default=5)
+    parser.add_argument("--update-per-epoch", type=int, default=1000)
+    parser.add_argument("--batch-size", type=int, default=64)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[64, 64])
+    parser.add_argument("--test-num", type=int, default=100)
+    parser.add_argument("--logdir", type=str, default="log")
+    parser.add_argument("--render", type=float, default=0.)
+    parser.add_argument(
+        "--load-buffer-name", type=str,
+        default="./expert_QRDQN_CartPole-v0.pkl",
+    )
+    parser.add_argument(
+        "--device", type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_discrete_cql(args=get_args()):
+    # envs
+    env = gym.make(args.task)
+    if args.task == 'CartPole-v0':
+        env.spec.reward_threshold = 190  # lower the goal
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    test_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device,
+              softmax=False, num_atoms=args.num_quantiles)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+
+    policy = DiscreteCQLPolicy(
+        net, optim, args.gamma, args.num_quantiles, args.n_step,
+        args.target_update_freq, min_q_weight=args.min_q_weight
+    ).to(args.device)
+    # buffer
+    assert os.path.exists(args.load_buffer_name), \
+        "Please run test_qrdqn.py first to get expert's data buffer."
+    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+
+    # collector
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
+
+    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
+    result = offline_trainer(
+        policy, buffer, test_collector,
+        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)

+    assert stop_fn(result['best_reward'])
+
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == "__main__": + test_discrete_cql(get_args()) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 7ac7601..f9349ff 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -299,7 +299,7 @@ class BaseVectorEnv(gym.Env): clip_max = 10.0 # this magic number is from openai baselines # see baselines/common/vec_env/vec_normalize.py#L10 obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps) - obs = np.clip(obs, -clip_max, clip_max) # type: ignore + obs = np.clip(obs, -clip_max, clip_max) return obs diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index f0177d9..30b1af9 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -14,6 +14,7 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -35,6 +36,7 @@ __all__ = [ "DiscreteSACPolicy", "ImitationPolicy", "DiscreteBCQPolicy", + "DiscreteCQLPolicy", "PSRLPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index dff0f85..19bdd1d 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -162,7 +162,7 @@ class BasePolicy(ABC, nn.Module): isinstance(act, np.ndarray): # currently this action mapping only supports np.ndarray action if self.action_bound_method == "clip": - act = np.clip(act, -1.0, 1.0) # type: ignore + act = np.clip(act, -1.0, 1.0) elif self.action_bound_method == "tanh": act = np.tanh(act) if self.action_scaling: diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py new file mode 100644 index 0000000..c6e1b50 --- /dev/null +++ b/tianshou/policy/imitation/discrete_cql.py @@ -0,0 +1,78 @@ +import torch +import numpy as np +import torch.nn.functional as F +from typing import Any, Dict + +from tianshou.policy import QRDQNPolicy +from tianshou.data import Batch, to_torch + + +class DiscreteCQLPolicy(QRDQNPolicy): + """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. Default to 200. + :param int estimation_step: the number of steps to look ahead. Default to 1. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + :param float min_q_weight: the weight for the cql loss. + + .. seealso:: + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. 
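+
+    A minimal construction sketch (the argument values are illustrative and
+    mirror the CartPole test in ``test_qrdqn_il_cql.py``)::
+
+        policy = DiscreteCQLPolicy(
+            model, optim, discount_factor=0.99, num_quantiles=200,
+            estimation_step=3, target_update_freq=320, min_q_weight=10.0,
+        )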
+ """ + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + min_q_weight: float = 10.0, + **kwargs: Any, + ) -> None: + super().__init__(model, optim, discount_factor, num_quantiles, estimation_step, + target_update_freq, reward_normalization, **kwargs) + self._min_q_weight = min_q_weight + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + all_dist = self(batch).logits + act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) + curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = (u * ( + self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() + ).abs()).sum(-1).mean(1) + qr_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + # add CQL loss + q = self.compute_q_value(all_dist, None) + dataset_expec = q.gather(1, act.unsqueeze(1)).mean() + negative_sampling = q.logsumexp(1).mean() + min_q_loss = negative_sampling - dataset_expec + loss = qr_loss + min_q_loss * self._min_q_weight + loss.backward() + self.optim.step() + self._iter += 1 + return { + "loss": loss.item(), + "loss/qr": qr_loss.item(), + "loss/cql": min_q_loss.item(), + } diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 01c77cb..a778173 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -124,7 +124,7 @@ class TRPOPolicy(NPGPolicy): " are poor and need to be changed.") # optimize citirc - for _ in range(self._optim_critic_iters): # type: ignore + for _ in range(self._optim_critic_iters): value = self.critic(b.obs).flatten() vf_loss = F.mse_loss(b.returns, value) self.optim.zero_grad() diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index 1ff1e00..e0d0676 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -91,5 +91,5 @@ class RunningMeanStd(object): m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count new_var = m_2 / total_count - self.mean, self.var = new_mean, new_var # type: ignore + self.mean, self.var = new_mean, new_var self.count = total_count