import os
import torch
import pickle
import pytest
import tempfile
import h5py
import numpy as np
from timeit import timeit

from tianshou.data import Batch, SegmentTree, \
    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5

if __name__ == '__main__':
    from env import MyTestEnv
else:  # pytest
    from test.base.env import MyTestEnv


def test_replaybuffer(size=10, bufsize=20):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, [a], rew, done, obs_next, info)
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    with pytest.raises(ValueError):
        buf._add_to_buffer('rew', np.array([1, 2, 3]))
    assert buf.act.dtype == np.object
    assert isinstance(buf.act[0], list)
    data, indice = buf.sample(bufsize * 2)
    assert (indice < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
    assert b.obs[0] == 1
    assert b.done[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.done[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
    assert np.all(b.info.b.c[1:] == 0.0)
    with pytest.raises(IndexError):
        b[22]
    b = ListReplayBuffer()
    with pytest.raises(NotImplementedError):
        b.sample(0)


def test_ignore_obs_next(size=10):
    # Issue 82
    buf = ReplayBuffer(size, ignore_obs_next=True)
    for i in range(size):
        buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
                     'mask2': np.array([i + 4, 0, 1, 0, 0]),
                     'mask': i},
                act={'act_id': i,
                     'position_id': i + 3},
                rew=i,
                done=i % 3 == 0,
                info={'if': i})
    indice = np.arange(len(buf))
    orig = np.arange(len(buf))
    data = buf[indice]
    data2 = buf[indice]
    assert isinstance(data, Batch)
    assert isinstance(data2, Batch)
    assert np.allclose(indice, orig)
    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
    assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
    buf.stack_num = 4
    data = buf[indice]
    data2 = buf[indice]
    assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
    assert np.allclose(data.obs_next.mask, np.array([
        [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
        [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
        [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
    assert np.allclose(data.info['if'], data2.info['if'])
    assert np.allclose(data.info['if'], np.array([
        [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
        [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
    assert data.obs_next


def test_stack(size=5, bufsize=9, stack_num=4):
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for i in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indice = np.arange(len(buf))
    assert np.allclose(buf.get(indice, 'obs')[..., 0], [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs'))
    assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next'))
    _, indice = buf2.sample(0)
    assert indice.tolist() == [2, 6]
    _, indice = buf2.sample(1)
    assert indice in [2, 6]
    with pytest.raises(IndexError):
        buf[bufsize * 2]


def test_priortized_replaybuffer(size=32, bufsize=15):
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        data, indice = buf.sample(len(buf) // 2)
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        assert len(buf) == min(bufsize, i + 1)
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    assert np.allclose(
        buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)


def test_update():
    buf1 = ReplayBuffer(4, stack_num=2)
    buf2 = ReplayBuffer(4, stack_num=2)
    for i in range(5):
        buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
                 done=i % 2 == 0, info={'incident': 'found'})
    assert len(buf1) > len(buf2)
    buf2.update(buf1)
    assert len(buf1) == len(buf2)
    assert (buf2[0].obs == buf1[1].obs).all()
    assert (buf2[-1].obs == buf1[0].obs).all()


def test_segtree():
    realop = np.sum
    # small test
    actual_len = 8
    tree = SegmentTree(actual_len)  # 1-15. 8-15 are leaf nodes
    assert len(tree) == actual_len
    assert np.all([tree[i] == 0. for i in range(actual_len)])
    with pytest.raises(IndexError):
        tree[actual_len]
    naive = np.zeros([actual_len])
    for _ in range(1000):
        # random choose a place to perform single update
        index = np.random.randint(actual_len)
        value = np.random.rand()
        naive[index] = value
        tree[index] = value
        for i in range(actual_len):
            for j in range(i + 1, actual_len):
                ref = realop(naive[i:j])
                out = tree.reduce(i, j)
                assert np.allclose(ref, out), (ref, out)
    assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
    assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
    # batch setitem
    for _ in range(1000):
        index = np.random.choice(actual_len, size=4)
        value = np.random.rand(4)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for i in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))
    # large test
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.zeros([actual_len])
    for _ in range(1000):
        index = np.random.choice(actual_len, size=64)
        value = np.random.rand(64)
        naive[index] = value
        tree[index] = value
        assert np.allclose(realop(naive), tree.reduce())
        for i in range(10):
            left = np.random.randint(actual_len)
            right = np.random.randint(left + 1, actual_len + 1)
            assert np.allclose(realop(naive[left:right]),
                               tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here
    naive = np.ones(actual_len, np.int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()

    # profile
    if __name__ == '__main__':
        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))


def test_pickle():
    size = 100
    vbuf = ReplayBuffer(size, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
    for i in range(3):
        lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0)
    for i in range(5):
        pbuf.add(obs=Batch(index=np.array([i])),
                 act=2, rew=rew, done=0, weight=np.random.rand())
    # save & load
    _vbuf = pickle.loads(pickle.dumps(vbuf))
    _lbuf = pickle.loads(pickle.dumps(lbuf))
    _pbuf = pickle.loads(pickle.dumps(pbuf))
    assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act)
    assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act)
    assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
    # make sure the meta var is identical
    assert _vbuf.stack_num == vbuf.stack_num
    assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
                       pbuf.weight[np.arange(len(pbuf))])


def test_hdf5():
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': rew,
            'done': 0,
            'info': {"number": {"n": i}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k]._maxsize == buffers[k]._maxsize
        assert _buffers[k]._index == buffers[k]._index
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    for k in ["array", "prioritized"]:
        assert isinstance(buffers[k].get(0, "info"), Batch)
        assert isinstance(_buffers[k].get(0, "info"), Batch)
    for k in ["array"]:
        assert np.all(
            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
        assert np.all(
            buffers[k][:].info.extra == _buffers[k][:].info.extra)

    for path in paths.values():
        os.remove(path)

    # raise exception when value cannot be pickled
    data = {"not_supported": lambda x: x*x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x*x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)


if __name__ == '__main__':
    test_hdf5()
    test_replaybuffer()
    test_ignore_obs_next()
    test_stack()
    test_pickle()
    test_segtree()
    test_priortized_replaybuffer()
    test_priortized_replaybuffer(233333, 200000)
    test_update()