diff --git a/.gitmodules b/.gitmodules index 9cca4d8a7..b6daa0910 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "minerl/env/Malmo"] path = minerl/env/Malmo url = https://github.com/cmu-rl/Malmo.git +[submodule "minerl/dependencies/pySmartDL"] + path = minerl/dependencies/pySmartDL + url = https://github.com/minerllabs/pySmartDL.git diff --git a/docs/source/assets/cropped_viewer.gif b/docs/source/assets/cropped_viewer.gif new file mode 100644 index 000000000..0abb0f642 Binary files /dev/null and b/docs/source/assets/cropped_viewer.gif differ diff --git a/docs/source/assets/minerl_viewer.gif b/docs/source/assets/minerl_viewer.gif new file mode 100644 index 000000000..04188b30b Binary files /dev/null and b/docs/source/assets/minerl_viewer.gif differ diff --git a/docs/source/index.rst b/docs/source/index.rst index ffb653289..bee892294 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -76,7 +76,7 @@ components: :caption: Tutorials and Guides :maxdepth: 2 - tutorials/getting_started + tutorials/index tutorials/first_agent tutorials/data_sampling diff --git a/docs/source/tutorials/data_sampling.rst b/docs/source/tutorials/data_sampling.rst index 1a320bbb1..b8296ba67 100644 --- a/docs/source/tutorials/data_sampling.rst +++ b/docs/source/tutorials/data_sampling.rst @@ -17,25 +17,94 @@ Now we can build the datast for :code:`MineRLObtainDiamond-v0` .. 
code-block:: python - data = minerl.data.make('MineRLObtainDiamond-v0') + data = minerl.data.make( + 'MineRLObtainDiamond-v0', + data_dir='/your/local/path') # Iterate through a single epoch gathering sequences of at most 32 steps - for obs, rew, done, act in data.seq_iter(num_epochs=1, max_sequence_len=32): - print("Number of diffrent actions:", len(act)) - for action in act: - print(act) - print("Number of diffrent observations:", len(obs), obs) - for observation in obs: - print(obs) - print("Rewards:", rew) - print("Dones:", done) + for current_state, action, reward, next_state, done \ + in data.sarsd_iter( + num_epochs=1, max_sequence_len=32): + # Print the POV @ the first step of the sequence + print(current_state['pov'][0]) + # Print the final reward of the sequence! + print(reward[-1]) + # Check if final (next_state) is terminal. + print(done[-1]) + # ... do something with the data. + print("At the end of trajectories the length " + "can be < max_sequence_len", len(reward)) -.. note:: + +.. warning:: The :code:`minerl` package uses environment variables to locate the data directory. For portability, plese define :code:`MINERL_DATA_ROOT` as - :code:`/your/local/path/data_texture_0_low_res` in your system environment variables. + :code:`/your/local/path/` in your system environment variables. + + + +============================================================= +Visualizing The Data :code:`minerl.viewer` +============================================================= + +To help you get familiar with the MineRL dataset, +the :code:`minerl` python package also provides a data trajectory viewer called +:code:`minerl.viewer`: + + +.. 
image:: ../assets/cropped_viewer.gif + :width: 90 % + :alt: + :align: center + + +The :code:`minerl.viewer` program lets you step through individual +trajectories, +showing the observation seen by the player, the action +they took (including camera, movement, and any action described by a MineRL +environment's action space), and the reward they received. + +.. exec:: + + import minerl + import minerl.viewer + + help_str = minerl.viewer.parser.format_help() + + print(".. code-block:: bash\n") + for line in help_str.split("\n"): + print("\t{}".format(line)) + +**Try it out on a random trajectory by running:** + +.. code-block:: bash + + # Make sure your MINERL_DATA_ROOT is set! + export MINERL_DATA_ROOT='/your/local/path' + + # Visualizes a random trajectory of MineRLObtainDiamondDense-v0 + python3 -m minerl.viewer MineRLObtainDiamondDense-v0 + + + +**Try it out on a specific trajectory by running:** + +.. exec:: + + import minerl + import minerl.viewer + + traj_name = minerl.viewer._DOC_TRAJ_NAME + + print(".. code-block:: bash\n") + + print('\t# Make sure your MINERL_DATA_ROOT is set!') + print("\texport MINERL_DATA_ROOT='/your/local/path'") + print("\t# Visualizes a specific trajectory. 
{}...".format(traj_name[:17])) + print("\tpython3 -m minerl.viewer MineRLTreechop-v0 \\") + print("\t\t{}".format(traj_name)) diff --git a/docs/source/tutorials/getting_started.rst b/docs/source/tutorials/index.rst similarity index 100% rename from docs/source/tutorials/getting_started.rst rename to docs/source/tutorials/index.rst diff --git a/minerl/__init__.py b/minerl/__init__.py index 33e7c6348..174f99d5f 100644 --- a/minerl/__init__.py +++ b/minerl/__init__.py @@ -1,2 +1,5 @@ + +import minerl.dependencies import minerl.data -import minerl.env \ No newline at end of file +import minerl.env +import minerl.env.spaces as spaces diff --git a/minerl/data/__init__.py b/minerl/data/__init__.py index 82230548e..9e6ef5def 100644 --- a/minerl/data/__init__.py +++ b/minerl/data/__init__.py @@ -2,6 +2,9 @@ from minerl.data.download import download import os +from minerl.data.version import DATA_VERSION, FILE_PREFIX, VERSION_FILE_NAME + +import minerl.data.version def make(environment=None , data_dir=None, num_workers=4, worker_batch_size=32, minimum_size_to_dequeue=32, force_download=False): """ @@ -35,6 +38,9 @@ def make(environment=None , data_dir=None, num_workers=4, worker_batch_size=32, raise ValueError("No data_dir provided and $MINERL_DATA_ROOT undefined." 
"Specify force_download=True to download default dataset") + + minerl.data.version.assert_version(data_dir) + d = DataPipeline( os.path.join(data_dir, environment), environment, diff --git a/minerl/data/data_pipeline.py b/minerl/data/data_pipeline.py index 43a156b17..c75e00bc7 100644 --- a/minerl/data/data_pipeline.py +++ b/minerl/data/data_pipeline.py @@ -1,3 +1,4 @@ +import functools import json import logging import multiprocessing @@ -16,6 +17,8 @@ logger = logging.getLogger(__name__) +from minerl.data.version import assert_version, assert_prefix + if os.name != "nt": class WindowsError(OSError): pass @@ -53,6 +56,24 @@ def __init__(self, self.size_to_dequeue = min_size_to_dequeue self.processing_pool = multiprocessing.Pool(self.number_of_workers) + self._action_space = gym.envs.registration.spec(self.environment)._kwargs['action_space'] + self._observation_space = gym.envs.registration.spec(self.environment)._kwargs['observation_space'] + + + @property + def action_space(self): + """ + Returns: action space of current MineRL environment + """ + return self._action_space + + @property + def observation_space(self): + """ + Returns: action space of current MineRL environment + """ + return self._observation_space + # Correct way # @staticmethod # def map_to_dict(handler_list: list, target_space: gym.spaces.space): @@ -97,9 +118,21 @@ def _map_to_dict(i: int, src: list, key: str, gym_space: gym.spaces.space, dst: index = _map_to_dict(index, handler_list, key, space, result) return result - def seq_iter(self, num_epochs=-1, max_sequence_len=32, seed=None): + def seq_iter(self, num_epochs=-1, max_sequence_len=32, queue_size=None, seed=None, include_metadata=False): + """DEPRECATED METHOD FOR SAMPLING DATA FROM THE MINERL DATASET. + + This function is now :code:`DataPipeline.sarsd_iter()` """ - Returns a generator for iterating through sequences of the dataset. + raise DeprecationWarning( + "The `DataPipeline.seq_iter` method is deprecated! 
Please use DataPipeline.sarsd_iter()." + "\nNOTE: The new method `DataPipeline.sarsd_iter` has a different return signature! " + "\n\t Please see how to use it @ http://www.minerl.io/docs/tutorials/data_sampling.html") + + + def sarsd_iter(self, num_epochs=-1, max_sequence_len=32, queue_size=None, seed=None, include_metadata=False): + """ + Returns a generator for iterating through (state, action, reward, next_state, is_terminal) + tuples in the dataset. Loads num_workers files at once as defined in minerl.data.make() and return up to max_sequence_len consecutive samples wrapped in a dict observation space @@ -110,24 +143,28 @@ def seq_iter(self, num_epochs=-1, max_sequence_len=32, seed=None): seed (int, optional): seed for random directory walk - note, specifying seed as well as a finite num_epochs will cause the ordering of examples to be the same after every call to seq_iter - Generates: - observation_dict, reward_list, done_list, action_dict + Yields: + A tuple of (state, player_action, reward_from_action, next_state, is_next_state_terminal). + Each element is in the format of the environment action/state/reward space and contains as many + samples are requested. """ - - logger.info("Starting seq iterator on {}".format(self.data_dir)) + logger.debug("Starting seq iterator on {}".format(self.data_dir)) if seed is not None: np.random.seed(seed) data_list = self._get_all_valid_recordings(self.data_dir) m = multiprocessing.Manager() - if max_sequence_len == -1: + if queue_size is not None: + max_size = queue_size + elif max_sequence_len == -1: max_size = 2*self.number_of_workers else: max_size = 16*self.number_of_workers data_queue = m.Queue(maxsize=max_size) + logger.debug(str(self.number_of_workers) + str(max_size)) # Setup arguments for the workers. 
- files = [(file_dir, max_sequence_len, data_queue) for file_dir in data_list] + files = [(file_dir, max_sequence_len, data_queue, 0, include_metadata) for file_dir in data_list] epoch = 0 @@ -137,7 +174,6 @@ def seq_iter(self, num_epochs=-1, max_sequence_len=32, seed=None): # for arg1, arg2, arg3 in files: # DataPipeline._load_data_pyfunc(arg1, arg2, arg3) # break - map_promise = self.processing_pool.starmap_async(DataPipeline._load_data_pyfunc, files) # random_queue = PriorityQueue(maxsize=pool_size) @@ -146,43 +182,75 @@ def seq_iter(self, num_epochs=-1, max_sequence_len=32, seed=None): while True: try: sequence = data_queue.get_nowait() - action_batch, observation_batch, reward_batch, done_batch = sequence + if include_metadata: + observation_seq, action_seq, reward_seq, next_observation_seq, done_seq, meta = sequence + else: + observation_seq, action_seq, reward_seq, next_observation_seq, done_seq = sequence # Wrap in dict gym_spec = gym.envs.registration.spec(self.environment) - action_dict = self.map_to_dict(action_batch, gym_spec._kwargs['action_space']) - observation_dict = self.map_to_dict(observation_batch, gym_spec._kwargs['observation_space']) - - yield observation_dict, reward_batch[0], done_batch[0], action_dict + observation_dict = DataPipeline.map_to_dict(observation_seq, gym_spec._kwargs['observation_space']) + action_dict = DataPipeline.map_to_dict(action_seq, gym_spec._kwargs['action_space']) + next_observation_dict = DataPipeline.map_to_dict(next_observation_seq, gym_spec._kwargs['observation_space']) + + if include_metadata: + yield observation_dict, action_dict, reward_seq[0], next_observation_dict, done_seq[0], meta + else: + yield observation_dict, action_dict, reward_seq[0], next_observation_dict, done_seq[0] + except Empty: if map_promise.ready(): epoch += 1 break else: time.sleep(0.1) + logger.debug("Epoch complete.") - logger.info("Epoch complete.") + def load_data(self, stream_name: str, skip_interval=0, include_metadata=False): + 
"""Iterates over an individual trajectory named stream_name. + + Args: + stream_name (str): The stream name desired to be iterated through. + skip_interval (int, optional): How many sices should be skipped.. Defaults to 0. + include_metadata (bool, optional): Whether or not meta data about the loaded trajectory should be included.. Defaults to False. - @staticmethod - def load_data(file_dir: str, environment: str, skip_interval=0,): - """ - Loading mechanism for loading a trajectory from a file as a generator - :param file_dir: file path to data directory - :param environment: the environment name e.g. MineRLObtainDiamond-v0 - :param skip_interval: NOT IMPLEMENTED how many frames to skip between observations - :return: iterator over files + Yields: + A tuple of (state, player_action, reward_from_action, next_state, is_next_state_terminal). + These are tuples are yielded in order of the episode. """ - seq = DataPipeline._load_data_pyfunc(file_dir, -1, None, skip_interval=skip_interval, environment=environment) - action_seq, observation_seq, reward_seq, done_seq = seq + if '/' in stream_name: + file_dir = stream_name + else: + file_dir = os.path.join(self.data_dir, stream_name) + seq = DataPipeline._load_data_pyfunc(file_dir, -1, None, skip_interval=skip_interval, + include_metadata=include_metadata) + if include_metadata: + observation_seq, action_seq, reward_seq, next_observation_seq, done_seq, meta = seq + else: + observation_seq, action_seq, reward_seq, next_observation_seq, done_seq = seq for idx in range(len(reward_seq[0])): # Wrap in dict - gym_spec = gym.envs.registration.spec(environment) - action_dict = DataPipeline.map_to_dict(action_seq[idx], gym_spec._kwargs['action_space']) - observation_dict = DataPipeline.map_to_dict(observation_seq[idx], gym_spec._kwargs['observation_space']) - - yield observation_dict, reward_seq[idx], done_seq[idx], action_dict + action_slice = [x[idx] for x in action_seq] + observation_slice = [x[idx] for x in observation_seq] + 
next_observation_slice = [x[idx] for x in next_observation_seq] + gym_spec = gym.envs.registration.spec(self.environment) + action_dict = DataPipeline.map_to_dict(action_slice, gym_spec._kwargs['action_space']) + observation_dict = DataPipeline.map_to_dict(observation_slice, gym_spec._kwargs['observation_space']) + next_observation_dict = DataPipeline.map_to_dict(next_observation_slice, gym_spec._kwargs['observation_space']) + + yield_list = [observation_dict, action_dict, reward_seq[0][idx], next_observation_dict, done_seq[0][idx]] + yield yield_list + [meta] if include_metadata else yield_list + + + def get_trajectory_names(self): + """Gets all the trajectory names + + Returns: + A list of experiment names: [description] + """ + return [os.path.basename(x) for x in self._get_all_valid_recordings(self.data_dir)] ############################ # PRIVATE METHODS # @@ -204,19 +272,22 @@ def _roundrobin(*iterables): # Todo: Make data pipeline split files per push. @staticmethod - def _load_data_pyfunc(file_dir: str, max_seq_len: int, data_queue, skip_interval=0, environment=None): + def _load_data_pyfunc(file_dir: str, max_seq_len: int, data_queue, skip_interval=0, include_metadata=False): """ Enqueueing mechanism for loading a trajectory from a file onto the data_queue :param file_dir: file path to data directory :param skip_interval: Number of time steps to skip between each sample :param max_seq_len: Number of time steps in each enqueued batch :param data_queue: multiprocessing data queue, or None to return streams directly - :param environment: environment used to wrap returned data as dict, or None to return raw data + :param include_metadata: whether or not to return an additional tuple containing metadata :return: """ + logger.debug("Loading from file {}".format(file_dir)) + video_path = str(os.path.join(file_dir, 'recording.mp4')) numpy_path = str(os.path.join(file_dir, 'rendered.npz')) + meta_path = str(os.path.join(file_dir, 'metadata.json')) try: # Start 
video decompression @@ -225,93 +296,133 @@ def _load_data_pyfunc(file_dir: str, max_seq_len: int, data_queue, skip_interval # Load numpy file state = np.load(numpy_path, allow_pickle=True) + # Load metadata file + with open(meta_path) as file: + meta = json.load(file) + if 'stream_name' not in meta: + meta['stream_name'] = file_dir + action_dict = {key: state[key] for key in state if key.startswith('action_')} reward_vec = state['reward'] info_dict = {key: state[key] for key in state if key.startswith('observation_')} - num_states = len(reward_vec) + num_states = len(reward_vec) + 1 + + # TEMP - calculate number of frames, fastest when max_seq_len == -1 + frames = [] + ret, frame_num = True, 0 + while ret: + ret, frame = cap.read() + if ret: + cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame) + frames.append(np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)) + frame_num += 1 + + max_frame_num = frame_num # int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) <- this is not correct! + if max_seq_len == -1: + stop_idx = 0 + frames = frames[frame_num - num_states:] + else: + frames = [] + frame_num, stop_idx = 0, 0 + + # Advance video capture past first i-frame to start of experiment + cap = cv2.VideoCapture(video_path) + for _ in range(max_frame_num - num_states): + ret, _ = cap.read() + frame_num += 1 + if not ret: + return None # Rendered Frames - frame_num, stop_idx = 0, 0 - max_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) observables = list(info_dict.keys()).copy() observables.append('pov') actionables = list(action_dict.keys()) - # Advance video capture past first i-frame to start of experiment - for _ in range(max_frame_num - num_states): - ret, _ = cap.read() - frame_num += 1 - if not ret: - return None - # Loop through the video and construct frames # of observations to be sent via the multiprocessing queue # in chunks of worker_batch_size to the batch_iter loop. 
while True: ret = True - frames = [] start_idx = stop_idx # Collect up to worker_batch_size number of frames try: - while ret and frame_num < max_frame_num and (len(frames) < max_seq_len or max_seq_len == -1): + # Go until max_seq_len +1 for S_t, A_t, -> R_t, S_{t+1}, D_{t+1} + while ret and frame_num < max_frame_num and (len(frames) < max_seq_len + 1 or max_seq_len == -1): ret, frame = cap.read() if ret: cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame) frames.append(np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)) frame_num += 1 except Exception as err: - print("error reading capture device:", err) + logger.error("error reading capture device:", err) raise err if len(frames) == 0: break + if frame_num == max_frame_num: + frames[-1] = frames[-2] + stop_idx = start_idx + len(frames) # print('Num frames in batch:', stop_idx - start_idx) # Load non-image data from npz - observation_data = [None for _ in observables] + current_observation_data = [None for _ in observables] action_data = [None for _ in actionables] + next_observation_data = [None for _ in observables] try: for i, key in enumerate(observables): if key == 'pov': - observation_data[i] = np.asanyarray(frames) + current_observation_data[i] = np.asanyarray(frames[:-1]) + next_observation_data[i] = np.asanyarray(frames[1:]) else: - observation_data[i] = np.asanyarray(state[key][start_idx:stop_idx]) + current_observation_data[i] = np.asanyarray(info_dict[key][start_idx:stop_idx-1]) + next_observation_data[i] = np.asanyarray(info_dict[key][start_idx+1:stop_idx]) + # We are getting S_t, A_t -> R_t, S_{t+1}, D_{t+1} so there are less actions and rewards for i, key in enumerate(actionables): - action_data[i] = np.asanyarray(state[key][start_idx:stop_idx]) + + action_data[i] = np.asanyarray(action_dict[key][start_idx: stop_idx-1]) - reward_data = np.asanyarray(reward_vec[start_idx:stop_idx], dtype=np.float32) + reward_data = np.asanyarray(reward_vec[start_idx:stop_idx-1], dtype=np.float32) done_data = 
[False for _ in range(stop_idx - start_idx)] if frame_num == max_frame_num: done_data[-1] = True except Exception as err: - print("error drawing batch from npz file:", err) + logger.error("error drawing batch from npz file:", err) raise err - batches = [action_data, observation_data, [reward_data], [np.array(done_data, dtype=np.bool)]] + batches = [current_observation_data, action_data, [reward_data], next_observation_data, [np.array(done_data, dtype=np.bool)]] + if include_metadata: + batches += [meta] if data_queue is None: return batches else: data_queue.put(batches) + logger.debug("Enqueued from file {}".format(file_dir)) if not ret: break + else: + frames = [] # logger.error("Finished") return None except WindowsError as e: - logger.info("Caught windows error {} - this is expected when closing the data pool".format(e)) + logger.debug("Caught windows error {} - this is expected when closing the data pool".format(e)) + return None + except BrokenPipeError: return None + except FileNotFoundError as e: + raise e except Exception as e: - logger.error("Exception \'{}\' caught on file \"{}\" by a worker of the data pipeline.".format(e, file_dir)) + logger.debug("Exception \'{}\' caught on file \"{}\" by a worker of the data pipeline.".format(e, file_dir)) return None @staticmethod @@ -322,9 +433,10 @@ def _get_all_valid_recordings(path): if os.path.isfile(path): return [] - # add dir to directorylist if it contains .txt files + # add dir to directory list if it contains .txt files if len([f for f in os.listdir(path) if f.endswith('.mp4')]) > 0: if len([f for f in os.listdir(path) if f.endswith('.npz')]) > 0: + assert_prefix(path) directoryList.append(path) for d in os.listdir(path): diff --git a/minerl/data/download.py b/minerl/data/download.py index 3f9181895..5290b66e8 100644 --- a/minerl/data/download.py +++ b/minerl/data/download.py @@ -3,12 +3,21 @@ from urllib.error import HTTPError import requests +import shutil import tarfile -import pySmartDL +import minerl 
+ + +from minerl.dependencies.pySmartDL import pySmartDL + import logging +from minerl.data.version import VERSION_FILE_NAME, DATA_VERSION, assert_version -def download(directory: os.path, resolution: str = 'low', texture_pack: int = 0, update_environment_variables=True, disable_cache=False): +logger = logging.getLogger(__name__) + + +def download(directory=None, resolution='low', texture_pack= 0, update_environment_variables=True, disable_cache=False): """Downloads MineRLv0 to specified directory. If directory is None, attempts to download to $MINERL_DATA_ROOT. Raises ValueError if both are undefined. @@ -27,19 +36,36 @@ def download(directory: os.path, resolution: str = 'low', texture_pack: int = 0, elif update_environment_variables: os.environ['MINERL_DATA_ROOT'] = os.path.expanduser( os.path.expandvars(os.path.normpath(directory))) + + if os.path.exists(directory): + try: + assert_version(directory) + except RuntimeError as r: + if r.comparison == "less": + raise r + logger.error(str(r)) + logger.error("Deleting existing data and forcing a data update!") + try: + shutil.rmtree(directory) + except Exception as e: + logger.error("Could not delete {}. 
Do you have permission?".format(directory)) + raise e + try: + os.makedirs(directory) + except: + pass - # TODO pull JSON defining dataset URLS from webserver instead of hard-coding - # TODO add hashed to website to verify downloads for mirrors - filename, hashname = "data_texture_{}_{}_res.tar.gz".format(texture_pack, resolution), \ - "data_texture_{}_{}_res.md5".format(texture_pack, resolution) - urls = ["https://router.sneakywines.me/minerl/" + filename] - hash_url = "https://router.sneakywines.me/minerl/" + hashname + filename, hashname = "minerl_v{}/data_texture_{}_{}_res.tar.gz".format(DATA_VERSION, texture_pack, resolution), \ + "minerl_v{}/data_texture_{}_{}_res.md5".format(DATA_VERSION, texture_pack, resolution) + urls = ["https://router.sneakywines.me/" + filename] + hash_url = "https://router.sneakywines.me/" + hashname try: + logger.info("Fetching download hash ...") response = requests.get(hash_url) md5_hash = response.text except TimeoutError: - print("Timeout error while retrieving hash for requested dataset version.") + logger.error("Timeout while retrieving hash for requested dataset version. Are you connected to the internet?") return None if disable_cache: @@ -47,28 +73,38 @@ def download(directory: os.path, resolution: str = 'low', texture_pack: int = 0, else: download_path = None - obj = pySmartDL.SmartDL(urls, progress_bar=True, logger=logging.getLogger(__name__), dest=download_path) + logger.info("Verifying download hash ...") + + + obj = pySmartDL.SmartDL(urls, progress_bar=True, logger=logger, dest=download_path, threads=20, timeout=60) obj.add_hash_verification('md5', md5_hash) try: obj.start() except pySmartDL.HashFailedException: - print("Hash check failed! Is server under maintenance?") + logger.error("Hash check failed! 
Is server under maintenance?") return None except pySmartDL.CanceledException: - print("Download canceled by user") + logger.error("Download canceled by user") + return None + except HTTPError as e: + logger.error("HTTP error encountered when downloading - please try again") + logger.error(e.errno) return None - except HTTPError: - print("HTTP error encountered when downloading - please try again") + except URLError as e: + logger.error("URL error encountered when downloading - please try again") + logger.error(e.errno) return None - except URLError: - print("URL error encountered when downloading - please try again") + except TimeoutError as e: + logger.error("Timeout encountered when downloading - is your connection stable") + logger.error(e.errno) return None - except IOError: - print("IO error encountered when downloading - please try again") + except IOError as e: + logger.error("IO error encountered when downloading - please try again") + logger.error(e.errno) return None - logging.info('Extracting downloaded files ... 
') + logging.info('Extracting downloaded files - this may take some time ') try: tf = None tf = tarfile.open(obj.get_dest(), mode="r:*") @@ -80,4 +116,9 @@ def download(directory: os.path, resolution: str = 'low', texture_pack: int = 0, if disable_cache: os.remove(obj.get_dest()) + try: + assert_version(directory) + except RuntimeError as r: + logger.error(str(r)) + return directory diff --git a/minerl/data/version.py b/minerl/data/version.py new file mode 100644 index 000000000..60d37ee4a --- /dev/null +++ b/minerl/data/version.py @@ -0,0 +1,68 @@ +import os +import re + +DATA_VERSION = 1 +FILE_PREFIX = "v{}_".format(DATA_VERSION) +VERSION_FILE_NAME = "VERSION" + +def assert_version(data_directory): + version_file = os.path.join(data_directory, VERSION_FILE_NAME) + + try: + assert os.path.exists(version_file), "more" + with open(version_file, 'r') as f: + try: + txt = int(f.read()) + except FileNotFoundError: + raise AssertionError("less") + except Exception as e: + print('VERSION number not found in data folder') + raise e + current_version = txt + + assert DATA_VERSION <= txt, "more" + assert DATA_VERSION >= txt, "less" + except AssertionError as e: + _raise_error(e, data_directory) + + +def assert_prefix(tail): + """Asserts that a file name satifies a certain prefix. + + Args: + file_name (str): The file name to test. + """ + try: + assert os.path.exists(tail), "File {} does not exist.".format(tail) + + m = re.search('v([0-9]+?)_', tail) + assert bool(m), "more" + ver = int(m.group(1)) + + assert DATA_VERSION <= ver, "more" + assert DATA_VERSION >= ver, "less" + + except AssertionError as e: + _raise_error(e) + + +def _raise_error(exception, directory=None): + comparison = str(exception) + if comparison == "more": + if directory: + dir_str = "directory={}".format(directory) + else: + dir_str = "" + e = RuntimeError( + "YOUR DATASET IS OUT OF DATE! 
The latest is on version v{} but yours is lower!\n\n" + "\tRe-download the data using `minerl.data.download({})`".format( + DATA_VERSION, dir_str)) + e.comparison = comparison + raise e + elif comparison == "less": + e = RuntimeError("YOUR MINERL PACKAGE IS OUT OF DATE! \n\n\tPlease upgrade with `pip3 install --upgrade minerl`") + e.comparison = comparison + raise e + else: + raise exception + diff --git a/minerl/dependencies/__init__.py b/minerl/dependencies/__init__.py new file mode 100644 index 000000000..d8d58db67 --- /dev/null +++ b/minerl/dependencies/__init__.py @@ -0,0 +1,2 @@ + +import minerl.dependencies.pySmartDL.pySmartDL \ No newline at end of file diff --git a/minerl/dependencies/pySmartDL b/minerl/dependencies/pySmartDL new file mode 160000 index 000000000..8f66988b6 --- /dev/null +++ b/minerl/dependencies/pySmartDL @@ -0,0 +1 @@ +Subproject commit 8f66988b6620384d676e1973f66518ca5663dee7 diff --git a/minerl/env/Malmo b/minerl/env/Malmo index dcd042d34..08863f5df 160000 --- a/minerl/env/Malmo +++ b/minerl/env/Malmo @@ -1 +1 @@ -Subproject commit dcd042d342ae6a26d4a00e522c477e3525ca3576 +Subproject commit 08863f5df13c5170c26fed47bf4f1cd1437280e5 diff --git a/minerl/viewer.py b/minerl/viewer.py new file mode 100644 index 000000000..9be424bb6 --- /dev/null +++ b/minerl/viewer.py @@ -0,0 +1,520 @@ +"""A module for viewing individual streams from the dataset! + +To use: +``` + python3 -m minerl.stream_viewier +``` +""" + +import argparse +from minerl.data import FILE_PREFIX + +_DOC_TRAJ_NAME="{}absolute_zucchini_basilisk-13_36805-50154".format(FILE_PREFIX) + +parser = argparse.ArgumentParser("python3 -m minerl.viewer") +parser.add_argument("environment", type=str, + help='The MineRL environment to visualize. e.g. MineRLObtainDiamondDense-v0') + +parser.add_argument("stream_name", type=str, nargs='?', default=None, + help="(optional) The name of the trajectory to visualize. " + "e.g. {}." 
+ "".format(_DOC_TRAJ_NAME)) + + + + +if __name__=="__main__": + import pyglet + import minerl + import sys + import os + if os.name == 'nt': + import msvcrt + getch = msvcrt + else: + import getch + + import random + + + try: + from pyglet.gl import * + except ImportError as e: + raise ImportError(''' + Error occured while running `from pyglet.gl import *` + HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. + If you're running on a server, you may need a virtual frame buffer; something like this should work: + 'xvfb-run -s \"-screen 0 1400x900x24\" python ' + ''') + + + import logging + import coloredlogs + import time + import tqdm + import matplotlib + matplotlib.use('Agg') + import numpy as np + import matplotlib.pyplot as plt + + plt.style.use('dark_background') + # plt.ion() + from gym.envs.classic_control import rendering + + coloredlogs.install(logging.DEBUG) + logger = logging.getLogger(__name__) + + import pyglet.window.key as key + + def parse_args(): + return parser.parse_args() + + + class Rect: + def __init__(self, x, y, w, h, color=None): + color = (255,255,255) if color is None else color + self.set(x, y, w, h, color) + + def draw(self): + pyglet.graphics.draw(4, pyglet.gl.GL_QUADS, self._quad, self._color_str) + + def set(self, x=None, y=None, w=None, h=None, color=None): + self._x = self._x if x is None else x + self._y = self._y if y is None else y + self._w = self._w if w is None else w + self._h = self._h if h is None else h + self._color = self._color if color is None else color + self._quad = ('v2f', (self._x, self._y, + self._x + self._w, self._y, + self._x + self._w, self._y + self._h, + self._x, self._y + self._h)) + self._color_str = ['c3B', self._color + self._color + self._color + self._color] + + @property + def center(self): + return self._x + self._w//2, self._y + self._h//2 + + @property + def x(self): + return self._x + + @property + def y(self): + return self._y + + @property + def 
height(self): + return self._h + + @property + def width(self): + return self._w + + + class Point: + def __init__(self, x, y, radius, color=None): + color = (255,255,255) if color is None else color + self.set(x, y, radius, color) + + def draw(self): + pyglet.graphics.draw_indexed(3, pyglet.gl.GL_TRIANGLES, + [0, 1, 2], + self._vertex, + self._color_str) + + # pyglet.graphics.draw(4, pyglet.gl.GL_QUADS, self._quad, self._color_str) + + def set(self, x=None, y=None, radius=None, color=None): + self._x = self._x if x is None else x + self._y = self._y if y is None else y + self._radius = self._radius if radius is None else radius + self._color = self._color if color is None else color + + height = self._radius/0.57735026919 + # TODO THIS IS INCORRECT LOL :) It's not a true radius. + self._vertex = ('v2f', (self._x-self._radius, self._y - height/2, + self._x + self._radius, self._y -height/2, + self._x, self._y + height/2)) + self._color_str = ['c3B', self._color + self._color + self._color] + + + + class ScaledWindowImageViewer(rendering.SimpleImageViewer): + def __init__(self, width, height): + super().__init__(None, 2700) + + if width > self.maxwidth: + scale = self.maxwidth / width + width = int(scale * width) + height = int(scale * height) + self.window = pyglet.window.Window(width=width, height=height, + display=self.display, vsync=False, resizable=True) + self.window.dispatch_events() + self.window.switch_to() + self.window.flip() + self.width = width + self.height = height + self.isopen = True + + @self.window.event + def on_resize(width, height): + self.width = width + self.height = height + + @self.window.event + def on_close(): + self.isopen = False + + def blit_texture(self, arr, pos_x=0, pos_y=0, width=None, height=None): + + assert len(arr.shape) == 3, "You passed in an image with the wrong number shape" + image = pyglet.image.ImageData(arr.shape[1], arr.shape[0], + 'RGB', arr.tobytes(), pitch=arr.shape[1]*-3) + gl.glTexParameteri(gl.GL_TEXTURE_2D, + 
gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST) + texture = image.get_texture() + texture.width = width if width else self.width + texture.height = height if height else self.height + + texture.blit(pos_x, pos_y) # draw + + def imshow(self, arr): + self.window.clear() + self.window.switch_to() + self.window.dispatch_events() + self.blit_texture(arr) + self.window.flip() + + SZ = 35 + BIG_FONT_SIZE = int(0.5*SZ) + SMALL_FONT_SIZE=int(0.4*SZ) + SMALLER_FONT_SIZE=int(0.35*SZ) + USING_COLOR = (255,0,0,255) + CAMERA_USING_COLOR =(162, 54, 69) + CUM_SUM_SPACE = .02 + class SampleViewer(ScaledWindowImageViewer): + + def __init__(self, environment, stream_name="", instructions=None, cum_rewards=None): + super().__init__(SZ*28, SZ*14) + + self.instructions = instructions + self.f_ox, self.fo_y = SZ, SZ + self.key_labels = self.make_key_labels() + self.window.set_caption("{}: {}".format(environment, stream_name)) + self.cum_rewards = cum_rewards + + + # Set up camera control stuff. + cam_x = self.f_ox + SZ*7 + cam_y = self.fo_y + cam_size = int(SZ*5*0.8) + self.camera_rect = Rect(cam_x,cam_y, cam_size, cam_size, color=(36, 109, 94)) + self.camera_labels = [ + pyglet.text.Label('Camera Control', font_size=SMALLER_FONT_SIZE, x= cam_x + cam_size/2, y= cam_y + cam_size +2, anchor_x='center'), + pyglet.text.Label('PITCH →', font_size=SMALLER_FONT_SIZE,font_name='Courier New', x= cam_x + cam_size/2, y= cam_y - SMALLER_FONT_SIZE- 4, anchor_x='center'), + pyglet.text.Label('Y\nA\nW\n↓', font_size=SMALLER_FONT_SIZE, font_name='Courier New', multiline=True,width=1, x= cam_x - SMALLER_FONT_SIZE -2, y= cam_y +cam_size/2, anchor_x='left') + ] + self.camera_labels[-1].document.set_style(0, len(self.camera_labels[-1].document.text),{'line_spacing': SMALLER_FONT_SIZE+2} ) + self.camera_info_label = pyglet.text.Label('[0,0]', font_size=SMALLER_FONT_SIZE-1, x= cam_x + cam_size, y= cam_y, anchor_x='right', anchor_y='bottom') + self.camera_point = Point(*self.camera_rect.center, radius=SZ/4) + + if 
self.instructions: + self.make_instructions(environment, stream_name) + + self.keys_down = [] + @self.window.event + def on_key_press(symbol, modifier): + if symbol not in self.keys_down: + self.keys_down.append(symbol) + + @self.window.event + def on_key_release(symbol, modifier): + if symbol in self.keys_down: + self.keys_down.remove(symbol) + if self.cum_rewards is not None: + self.make_cum_reward_plotter() + + def make_instructions(self, environment, stream_name): + if len(stream_name) >= 46: + stream_name = stream_name[:44] + "..." + + self.instructions_labels = [ + pyglet.text.Label(environment, font_size=BIG_FONT_SIZE, y = self.height-SZ, x = SZ/2, anchor_x='left'), + pyglet.text.Label(stream_name, font_size=SMALLER_FONT_SIZE, font_name= 'Courier New', anchor_x='left', x=1.4*SZ, y= self.height-SZ*1.5), + pyglet.text.Label(self.instructions, multiline=True, width=12*SZ, font_size=SMALLER_FONT_SIZE, anchor_x='left', x= SZ/2, y= self.height-SZ*2.3), + ] + self.progress_label = pyglet.text.Label("", multiline=False, width=14*SZ, font_name='Courier New', font_size=SMALLER_FONT_SIZE, anchor_x='left', x= 14*SZ, y= 2) + self.progress_label.set_style('background_color', (0,0,0,255)) + self.meter = tqdm.tqdm() + + + def make_cum_reward_plotter(self): + # First let us matplot lib plot the cum rewards to an image. + # Make a random plot... 
+ # plt.clf() + fig = plt.figure(figsize=(2,2)) + ax = fig.add_subplot(111) + + plt.subplots_adjust(left=0.0, bottom=0, right=1, top=1, wspace=0, hspace=0) + # plt.title("Cumulative Rewards") + + # fig.patch.set_visible(False) + # plt.gca().axis('off') + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + + plt.plot(self.cum_rewards) + plt.xticks([]) + plt.yticks([]) + self._total_space = len(self.cum_rewards)*(CUM_SUM_SPACE) + plt.xlim(- self._total_space, len(self.cum_rewards) + self._total_space) + + # If we haven't already shown or saved the plot, then we need to + # draw the figure first... + fig.canvas.draw() + + # Now we can save it to a numpy array. + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + # print(data) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))[:,:,:3] + + self.cum_reward_image =data + + # Create the rectangle. + width = height = int(self.camera_rect.height) + x,y = self.camera_rect.center + y += int(self.camera_rect.height//2 + SZ*2) + x -= width//2 + self.cum_reward_rect = Rect(x-1,y-1, width+2,height+2, color=(255,255,255)) + self.cum_reward_label = pyglet.text.Label( + 'Net Reward', font_size=SMALLER_FONT_SIZE, x=x+width//2, y=y+height+5, + anchor_x='center', align='center') + self.cum_reward_line = Rect(x, y, w=2, h=height, color=CAMERA_USING_COLOR) + self.cum_reward_info_label = pyglet.text.Label('', multiline=True, width=width, + font_size=SMALLER_FONT_SIZE/1.1, font_name='Courier New', x=x+3, y=y-3, anchor_x='left', anchor_y='top') + + + def update_reward_info(self, rew, step, max_step): + self.cum_reward_line.set(x=self.cum_reward_rect.x+ 1 + int( + (step + self._total_space)/(max_step + self._total_space*2)*(self.cum_reward_rect.width-2) + )) + self.cum_reward_info_label.document.text = ( + "r(t): {0:.2f}\nnet: {1:.2f}".format(rew, self.cum_rewards[step]) + ) + if rew > 0: + 
self.cum_reward_info_label.set_style('color', (255,0,0,255)) + else: + self.cum_reward_info_label.set_style('color', (255,255,255,255)) + + def make_key_labels(self): + keys = {} + default_params = { + "font_name": 'Courier New', + "font_size": BIG_FONT_SIZE, + "anchor_x":'center', "anchor_y":'center'} + info_text_params = { + "font_name": 'Courier New', + "font_size": SMALL_FONT_SIZE, + "anchor_y":'center' + } + fo_x, fo_y = self.f_ox, self.fo_y + o_x, o_y = fo_x + SZ*3, fo_y + SZ*2 + + keys.update( { + "forward": pyglet.text.Label('↑', x=o_x, y= o_y + SZ, **default_params), + "left": pyglet.text.Label('←', x=o_x - SZ, y= o_y + SZ/2 , **default_params), + "back": pyglet.text.Label('↓', x=o_x , y= o_y , **default_params), + "right": pyglet.text.Label('→', x=o_x + SZ, y= o_y +SZ/2, **default_params), + }) + + keys["attack"] = pyglet.text.Label('attack', x=o_x + SZ*1.5, y= o_y +SZ*1.2 ,anchor_x='center', **info_text_params) + + # sprint & sneak + + o_x, o_y = fo_x + SZ, fo_y + keys.update({ + "sprint": pyglet.text.Label('sprint', x=o_x + SZ*3.5, y= o_y, anchor_x='center', **info_text_params), + "sneak": pyglet.text.Label('sneak', x=o_x , y= o_y ,anchor_x='center', **info_text_params)}) + + # jump + o_x, o_y = fo_x + SZ*3, fo_y + SZ + keys["jump"] = pyglet.text.Label('[ JUMP ]', x=o_x, y= o_y ,anchor_x='center', **info_text_params) + + o_x, o_y = fo_x + SZ/4, fo_y + keys["place"] = pyglet.text.Label('', x=o_x, y= o_y +SZ*6, anchor_x='left', **info_text_params) + keys["craft"] = pyglet.text.Label('', x=o_x, y= o_y +SZ*5.4, anchor_x='left', **info_text_params) + keys["nearbyCraft"] = pyglet.text.Label('', x=o_x, y= o_y +SZ*4.8,anchor_x='left', **info_text_params) + keys["nearbySmelt"] = pyglet.text.Label('', x=o_x, y= o_y +SZ*4.2,anchor_x='left', **info_text_params) + + return keys + + def process_actions(self, action): + for k in self.key_labels: + self.key_labels[k].set_style('color', (128,128,128,255)) + + + for x in action: + try: + if action[x] > 0: + 
self.key_labels[x].set_style('color', USING_COLOR) + except: + pass + + # Update mouse poisiton. + delta_y, delta_x = action['camera'] + self.camera_info_label.document.text = "[{0:.2f},{1:.2f}]".format(float(delta_y), float(delta_x)) + delta_x = np.clip(delta_x/60, -1,1)*self.camera_rect.width/2 + delta_y = np.clip(delta_y/60,-1,1)*self.camera_rect.height/2 + center_x, center_y = self.camera_rect.center + + if abs(delta_x) > 0 or abs(delta_y) > 0: + camera_color = CAMERA_USING_COLOR + else: + camera_color = (255,255,255) + self.camera_point.set(center_x + delta_x, center_y + delta_y, color=camera_color) + # self.camera_info_label.set_style('color', list(camera_color)+ [255]) + + # self.key_labels["a"].set_style('background_color', (255,255,0,255)) + + for a, p in [ + ("place", "place "), + ("nearbyCraft", 'nearbyCraft'), + ("craft", 'craft '), + ("nearbySmelt", 'nearbySmelt') ]: + if a in action: + self.key_labels[a].set_style('font_size', SMALL_FONT_SIZE) + self.key_labels[a].document.text = "{} {}".format(p, action[a]) + else: + self.key_labels[a].document.text = "" + + + def render(self, obs,reward, done,action, step, max): + self.window.clear() + self.window.switch_to() + e = self.window.dispatch_events() + + self.blit_texture(obs["pov"], SZ*14, 0, self.width -SZ*14, self.width -SZ*14) + self.process_actions(action) + + for label in self.key_labels: + self.key_labels[label].draw() + + self.camera_rect.draw() + for label in self.camera_labels: + label.draw() + self.camera_info_label.draw() + self.camera_point.draw() + + if self.instructions: + for label in self.instructions_labels: + label.draw() + prog_str = self.meter.format_meter(step, max, 0, ncols=52) + " "*48 + + self.progress_label.document.text = prog_str + self.progress_label.draw() + + if self.cum_rewards is not None: + self.update_reward_info(reward, step,max) + self.cum_reward_label.draw() + self.cum_reward_rect.draw() + self.blit_texture(self.cum_reward_image, + self.cum_reward_rect.x+1, + 
self.cum_reward_rect.y+1, + width=self.cum_reward_rect.width-2, + height= self.cum_reward_rect.height-2) + self.cum_reward_line.draw() + self.cum_reward_info_label.draw() + + + self.window.flip() + + QUIT=key.Q + FORWARD=key.RIGHT + BACK=key.LEFT + SPEED_UP =key.X + SLOWE_DOWN =key.Z + FRAME_UP = key.UP + FRAME_DOWN = key.DOWN + + def main(opts): + instructions_txt = ( + "Instructions:\n" + " → - Go forward at speed \n" + " ← - Go back at speed \n" + " ↑ - Move forward 1 frame \n" + " ↓ - Move back 1 frame \n" + " X - Speed up 2X \n" + " Z - Slow down 2X \n" + " Q - Quit \n" + ) + logger.info("Welcome to the MineRL Stream viewer! \n" + instructions_txt) + + logger.info("Building data pipeline for {}".format(opts.environment)) + data = minerl.data.make(opts.environment) + + # for _ in data.seq_iter( 1, -1, None, None, include_metadata=True): + # print(_[-1]) + # pass + if opts.stream_name == None: + trajs = data.get_trajectory_names() + opts.stream_name = random.choice(trajs) + + logger.info("Loading data for {}...".format(opts.stream_name)) + data_frames = list(data.load_data(opts.stream_name, include_metadata=True)) + meta = data_frames[0][-1] + cum_rewards = np.cumsum([x[2] for x in data_frames]) + file_len = len(data_frames) + logger.info("Data loading complete!".format(opts.stream_name)) + logger.info("META DATA: {}".format(meta)) + + height, width = data.observation_space.spaces['pov'].shape[:2] + + controls_viewer = SampleViewer(opts.environment, opts.stream_name, + instructions=instructions_txt, cum_rewards=cum_rewards) + + + indicator = tqdm.tqdm(total=file_len) + key = '' + position = 0 + speed = 1 + new_position = 0 + leave = False + + + + while not leave: + indicator.update(new_position - position) + indicator.refresh() + position = new_position + + # Display video viewer + obs, action, rew, next_obs, done, meta = data_frames[position] + # print(obs['inventory']) + + # print(cum_rewards[position]) + # Display info stuff! 
+ controls_viewer.render(obs,rew, done,action, position, len(data_frames)) + if QUIT in controls_viewer.keys_down: + leave = True + elif FORWARD in controls_viewer.keys_down: + new_position = min(position + speed, file_len -1) + elif BACK in controls_viewer.keys_down: + new_position = max(position -speed, 0) + elif FRAME_UP in controls_viewer.keys_down: + new_position = min(position + 1, len(data_frames)-1) + controls_viewer.keys_down.remove(FRAME_UP) + elif FRAME_DOWN in controls_viewer.keys_down: + new_position = max(position - 1, 0) + controls_viewer.keys_down.remove(FRAME_DOWN) + elif SPEED_UP in controls_viewer.keys_down: + speed*=2 + controls_viewer.keys_down.remove(SPEED_UP) + elif SLOWE_DOWN in controls_viewer.keys_down: + speed = max(1, speed //2) + controls_viewer.keys_down.remove(SLOWE_DOWN) + + time.sleep(0.05) + + main(parse_args()) diff --git a/requirements.txt b/requirements.txt index 8eea90d0f..f873966c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ gym>=0.12.1 opencv-python>=4.1.0.25 setuptools>=40.6.2 -tqdm>=4.31.1 +tqdm>=4.32.2 numpy>=1.16.2 requests>=2.20.0 ipython>=7.5.0 typing>=3.6.6 lxml>=4.3.3 -pySmartDL>=1.3.1 psutil>=5.6.2 -pySmartDL>=1.3.1 -Pyro4>=4.76 \ No newline at end of file +Pyro4>=4.76 +getch>=1.0; sys_platform != 'win32' and sys_platform != 'cygwin' +coloredlogs>=10.0 +matplotlib==3.0.3 diff --git a/setup.py b/setup.py index 7d75dc500..c82576bac 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ def package_files(directory): setuptools.setup( name='minerl', - version='0.1.18', + version='0.2.0', description='MineRL environment and data loader for reinforcement learning from human demonstration in Minecraft', long_description=markdown, long_description_content_type="text/markdown", diff --git a/tests/excluded/sample.py b/tests/excluded/sample.py index 60e2ed44f..25fbb2813 100644 --- a/tests/excluded/sample.py +++ b/tests/excluded/sample.py @@ -13,12 +13,12 @@ NUM_EPISODES = 4 -def 
step_data(environment='MineRLTreechop-v0'): +def step_data(environment='MineRLObtainDiamond-v0'): d = minerl.data.make(environment) # Iterate through batches of data counter = 0 - for act, obs, rew in itertools.islice(d.batch_iter(3, None), 600): + for obs, act, rew, next_obs, done in d.sarsd_iter(num_epochs=-1, max_sequence_len=-1, queue_size=1, seed=1234): print("Act shape:", len(act), act) print("Obs shape:", len(obs), obs) print("Rew shape:", len(rew), rew)