259 lines
7.4 KiB
Python
259 lines
7.4 KiB
Python
import gym
|
|
import numpy as np
|
|
import os
|
|
import pickle
|
|
import random
|
|
import tempfile
|
|
import zipfile
|
|
|
|
|
|
def zipsame(*seqs):
    """Zip sequences together after asserting they all have the same length.

    At least one sequence must be given (the first one is indexed directly).
    Returns the lazy ``zip`` iterator over all sequences.
    """
    first, rest = seqs[0], seqs[1:]
    n = len(first)
    # Unlike plain zip, refuse to silently truncate to the shortest input.
    assert not any(len(other) != n for other in rest)
    return zip(first, *rest)
|
|
|
|
|
|
def unpack(seq, sizes):
    """
    Split 'seq' into consecutive chunks whose lengths are given by 'sizes'.

    Each entry of 'sizes' is either an int (yield a list of that many
    elements) or None (yield one bare element, not wrapped in a list).
    This is a generator; the total-length check runs when iteration starts.

    Example:
    unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6])
    """
    items = list(seq)
    source = iter(items)
    assert sum(1 if s is None else s for s in sizes) == len(items), "Trying to unpack %s into %s" % (items, sizes)
    for size in sizes:
        if size is None:
            yield next(source)
            continue
        yield [next(source) for _ in range(size)]
|
|
|
|
|
|
class EzPickle(object):
    """Mixin whose instances are pickled and unpickled via their constructor arguments.

    Example usage::

        class Dog(Animal, EzPickle):
            def __init__(self, furcolor, tailkind="bushy"):
                Animal.__init__(self)
                EzPickle.__init__(self, furcolor, tailkind)
                ...

    When this object is unpickled, a new Dog will be constructed by passing the provided
    furcolor and tailkind into the constructor. However, philosophers are still not sure
    whether it is still the same dog.

    This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
    and Atari.
    """

    def __init__(self, *args, **kwargs):
        # Remember exactly what the subclass constructor received so it can be replayed.
        self._ezpickle_args, self._ezpickle_kwargs = args, kwargs

    def __getstate__(self):
        # Only the recorded constructor arguments are pickled, not the rest of __dict__.
        return dict(_ezpickle_args=self._ezpickle_args, _ezpickle_kwargs=self._ezpickle_kwargs)

    def __setstate__(self, d):
        # Rebuild a fresh instance through the real constructor, then adopt its attributes.
        cls = type(self)
        rebuilt = cls(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        self.__dict__.update(rebuilt.__dict__)
|
|
|
|
|
|
def set_global_seeds(i):
    """Seed the global random number generators used by this code base.

    Seeds TensorFlow (only if it can be imported), then NumPy's global RNG,
    then Python's ``random`` module, all with the same value.

    Parameters
    ----------
    i: int
        seed value passed to every generator
    """
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is optional; seeding the other generators still applies.
        pass
    else:
        # TF 2.x removed the top-level ``tf.set_random_seed``; it lives at
        # ``tf.random.set_seed`` there. Fall back to the 1.x name otherwise.
        set_seed = getattr(getattr(tf, "random", None), "set_seed", None)
        if set_seed is None:
            set_seed = tf.set_random_seed
        set_seed(i)
    np.random.seed(i)
    random.seed(i)
|
|
|
|
|
|
def pretty_eta(seconds_left):
    """Print the number of seconds in human readable format.

    Examples:
        2 days
        2 hours and 37 minutes
        less than a minute

    At most the two most significant units are shown (days and hours, or
    hours and minutes); leftover seconds are dropped.

    Parameters
    ----------
    seconds_left: int or float
        Number of seconds to be converted to the ETA. Negative values are
        treated as 0. Floats produce the same text as their whole part
        (e.g. 3600.0 -> "1 hour", not "1.0 hour").

    Returns
    -------
    eta: str
        String representing the pretty ETA.
    """
    # A negative ETA (e.g. from an overshooting estimate) reads as "less than a minute".
    seconds_left = max(seconds_left, 0)
    minutes_left = seconds_left // 60
    hours_left, minutes_left = divmod(minutes_left, 60)
    days_left, hours_left = divmod(hours_left, 24)

    def helper(cnt, name):
        # Only called with a positive whole count; int() drops the ".0"
        # that float input would otherwise print.
        cnt = int(cnt)
        return "{} {}{}".format(cnt, name, 's' if cnt > 1 else '')

    if days_left > 0:
        msg = helper(days_left, 'day')
        if hours_left > 0:
            msg += ' and ' + helper(hours_left, 'hour')
        return msg
    if hours_left > 0:
        msg = helper(hours_left, 'hour')
        if minutes_left > 0:
            msg += ' and ' + helper(minutes_left, 'minute')
        return msg
    if minutes_left > 0:
        return helper(minutes_left, 'minute')
    return 'less than a minute'
|
|
|
|
|
|
class RunningAvg(object):
    def __init__(self, gamma, init_value=None):
        """Keep a running estimate of a quantity. This is a bit like mean
        but more sensitive to recent changes.

        The estimate is an exponential moving average:
        ``value = gamma * value + (1 - gamma) * new_val``.

        Parameters
        ----------
        gamma: float
            Must be between 0 and 1, where 0 is the most sensitive to recent
            changes.
        init_value: float or None
            Initial value of the estimate. If None, it will be set on the first update.
        """
        self._value = init_value
        self._gamma = gamma

    def update(self, new_val):
        """Update the estimate.

        Parameters
        ----------
        new_val: float
            newly observed value of the estimated quantity.
        """
        if self._value is None:
            # First observation: adopt it directly instead of blending with nothing.
            self._value = new_val
        else:
            self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val

    def __float__(self):
        """Get the current estimate as a float.

        Raises
        ------
        TypeError
            if there is no estimate yet (no init_value and no update() call).
        """
        if self._value is None:
            raise TypeError("RunningAvg has no estimate yet: pass init_value or call update() first")
        # __float__ must return a real float; returning an int (e.g. after
        # update(2) or init_value=0) would make float(obj) raise TypeError.
        return float(self._value)
|
|
|
|
def boolean_flag(parser, name, default=False, help=None):
    """Register a pair of on/off switches that control one boolean option.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser that receives the two new arguments
    name: str
        ``--<name>`` sets the option to True and ``--no-<name>`` sets it to
        False; dashes in ``name`` become underscores in the attribute name
    default: bool or None
        value used when neither switch is given
    help: str
        help string attached to ``--<name>``
    """
    attr = "_".join(name.split("-"))
    on_switch = "--" + name
    off_switch = "--no-" + name
    # The on-switch is registered first so that its ``default`` is the one
    # argparse applies to the shared destination.
    parser.add_argument(on_switch, action="store_true", default=default, dest=attr, help=help)
    parser.add_argument(off_switch, action="store_false", dest=attr)
|
|
|
|
|
|
def get_wrapper_by_name(env, classname):
    """Given a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied.

    The search starts at ``env`` itself and follows ``.env`` through each
    ``gym.Wrapper`` layer until a class-name match is found or an unwrapped
    environment is reached.

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper (compared against the object's class ``__name__``)

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname

    Raises
    ------
    ValueError
        if no layer in the wrapper chain has the requested class name
    """
    currentenv = env
    while True:
        # Read the class name directly rather than via the ``class_name()``
        # helper, which may be missing on the innermost (unwrapped) env; that
        # would surface as AttributeError instead of the documented ValueError.
        if classname == type(currentenv).__name__:
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
|
|
|
|
|
|
def relatively_safe_pickle_dump(obj, path, compression=False):
    """This is just like regular pickle dump, except from the fact that failure cases are
    different:

        - The data is first written to ``path + ".relatively_safe"`` and only then
          moved onto ``path``, so ``path`` never holds a partially written pickle.
        - If there was a different file at the path, that file will remain unchanged in the
          event of failure (provided that the filesystem rename is atomic).
        - If writing fails, the temporary file is removed on a best-effort basis; any
          leftover is overwritten by the next call for the same path.

    The intended use case is periodic checkpoints of experiment state, such that we never
    corrupt previous checkpoints if the current one fails.

    Parameters
    ----------
    obj: object
        object to pickle
    path: str
        path to the output file
    compression: bool
        if true, the pickle is stored as the member "data" of a ZIP_DEFLATED zip archive
    """
    temp_storage = path + ".relatively_safe"
    try:
        if compression:
            # Using gzip here would be simpler, but the size is limited to 2GB
            with tempfile.NamedTemporaryFile() as uncompressed_file:
                pickle.dump(obj, uncompressed_file)
                uncompressed_file.file.flush()
                with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
                    myzip.write(uncompressed_file.name, "data")
        else:
            with open(temp_storage, "wb") as f:
                pickle.dump(obj, f)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(temp_storage)
        except OSError:
            # Temp file was never created or is already gone; nothing to clean up.
            pass
        raise
    # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
    os.replace(temp_storage, path)
|
|
|
|
|
|
def pickle_load(path, compression=False):
    """Unpickle a possibly compressed pickle.

    NOTE(review): unpickling can execute arbitrary code; only load files
    from trusted sources.

    Parameters
    ----------
    path: str
        path to the input file
    compression: bool
        if true, ``path`` is opened as a zip archive and the pickle is read
        from its member named "data".

    Returns
    -------
    obj: object
        the unpickled object
    """
    if not compression:
        with open(path, "rb") as handle:
            return pickle.load(handle)
    with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as archive:
        with archive.open("data") as member:
            return pickle.load(member)
|