# Source code for tensorlayer.db

#! /usr/bin/python
# -*- coding: utf-8 -*-

import pickle
import time
import os
import sys
from datetime import datetime

import gridfs
import pymongo
from tensorlayer.files import load_graph_and_params, exists_or_mkdir, del_folder
from tensorlayer import logging

import tensorflow as tf
import numpy as np


class TensorHub(object):
    """A MongoDB-based manager that helps you manage datasets, network architectures,
    parameters, logs and tasks.

    Parameters
    -------------
    ip : str
        Localhost or IP address.
    port : int
        Port number.
    dbname : str
        Database name.
    username : str or None
        User name. Set to None (or leave the default, the string ``'None'``) if you do
        not need authentication.
    password : str
        Password. Only used when a user name is given; never printed.
    project_name : str or None
        Experiment key for this entire project, similar to a repository name on GitHub.
        If None, the path of the running script without its extension is used.

    Attributes
    ------------
    ip, port, dbname and other input parameters : see above
        See above.
    project_name : str
        The given project name; if none is given, derived from the script path.
    db : mongodb database
        ``pymongo.MongoClient(ip, port)[dbname]``.

    """

    def __init__(
            self, ip='localhost', port=27017, dbname='dbname', username='None', password='password', project_name=None
    ):
        self.ip = ip
        self.port = port
        self.dbname = dbname
        self.username = username

        print("[Database] Initializing ...")
        # connect mongodb
        client = pymongo.MongoClient(ip, port)
        self.db = client[dbname]

        # The default user name is the *string* 'None', so treat it exactly like None.
        # Authenticate only when a real user name was supplied, and never echo the password.
        if username is not None and username != 'None':
            self.db.authenticate(username, password)
        else:
            print("[Database] No username given, it works if authentication is not required")

        if project_name is None:
            # Drop only the final extension: './train.py' -> './train'
            # (a plain split('.') would give '' for that path).
            self.project_name = os.path.splitext(sys.argv[0])[0]
            print("[Database] No project_name given, use {}".format(self.project_name))
        else:
            self.project_name = project_name

        # define file system (Buckets)
        self.dataset_fs = gridfs.GridFS(self.db, collection="datasetFilesystem")
        self.model_fs = gridfs.GridFS(self.db, collection="modelfs")

        print("[Database] Connected ")
        _s = "[Database] Info:\n"
        _s += " ip : {}\n".format(self.ip)
        _s += " port : {}\n".format(self.port)
        _s += " dbname : {}\n".format(self.dbname)
        _s += " username : {}\n".format(self.username)
        _s += " password : {}\n".format("*******")
        _s += " project_name : {}\n".format(self.project_name)
        self._s = _s
        print(self._s)

    def __str__(self):
        """Return the connection summary built in ``__init__`` (password masked)."""
        return self._s

    def _fill_project_info(self, args):
        """Add ``project_name`` to ``args`` in place and return ``args``.

        Every save/find/delete method of this class calls this, so all documents and
        queries are scoped to the current project.
        """
        args.update({'project_name': self.project_name})
        return args

    @staticmethod
    def _serialization(ps):
        """Serialize ``ps`` to bytes with pickle (highest protocol)."""
        return pickle.dumps(ps, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def _deserialization(ps):
        """Deserialize pickled bytes.

        NOTE(review): ``pickle.loads`` can execute arbitrary code; only use this with a
        database you trust.
        """
        return pickle.loads(ps)

    @staticmethod
    def _log_exception(e):
        """Log the exception currently being handled, with file name and line number.

        Must be called from inside an ``except`` block (it reads ``sys.exc_info()``).
        """
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))

    # =========================== MODELS ================================
    def save_model(self, network=None, model_name='model', **kwargs):
        """Save model architecture and parameters into the database; a timestamp is added automatically.

        Parameters
        ----------
        network : TensorLayer layer
            TensorLayer layer instance.
        model_name : str
            The name/key of the model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        Save model architecture and parameters into database.

        >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')

        Load one model with parameters from database (run this in other script)

        >>> net = db.find_top_model(sess=sess, accuracy=0.8, loss=2.3)

        Find and load the latest model.

        >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.DESCENDING)])
        >>> net = db.find_top_model(sess=sess, sort=[("time", -1)])

        Find and load the oldest model.

        >>> net = db.find_top_model(sess=sess, sort=[("time", pymongo.ASCENDING)])
        >>> net = db.find_top_model(sess=sess, sort=[("time", 1)])

        Get model information

        >>> net._accuracy
        ... 0.8

        Returns
        ---------
        boolean : True for success, False for fail.

        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)  # put project_name into kwargs

        params = network.get_all_params()

        s = time.time()

        kwargs.update({'architecture': network.all_graphs})

        try:
            # Parameters can be large, so they go into GridFS; the document keeps only their id.
            params_id = self.model_fs.put(self._serialization(params))
            kwargs.update({'params_id': params_id, 'time': datetime.utcnow()})
            self.db.Model.insert_one(kwargs)
            print("[Database] Save model: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save model: FAIL")
            return False

    def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
        """Find and return the model architecture and parameters that best match the requirement.

        Parameters
        ----------
        sess : Session
            TensorFlow session.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer layer or False
            False if no model matches or loading fails. The returned network also carries
            every field of the stored document as ``_<field>``, e.g. ``net._accuracy``.

        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        # Scratch folder (relative to the working directory) used to hand the graph and
        # parameters to load_graph_and_params; removed in `finally` even when loading fails.
        _temp_file_name = '_find_one_model_ztemp_file'
        try:
            exists_or_mkdir(_temp_file_name, False)
            with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
                pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)

            params = self._deserialization(self.model_fs.get(params_id).read())
            np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)

            network = load_graph_and_params(name=_temp_file_name, sess=sess)

            pc = self.db.Model.find(kwargs)
            print(
                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                    kwargs, sort, _datetime, round(time.time() - s, 2)
                )
            )

            # put all information of the model document into the TL layer
            for key in d:
                network.__dict__.update({"_%s" % key: d[key]})

            # check whether more parameters match the requirement
            n_params = len(pc.distinct('params_id'))
            if n_params != 1:
                print(" Note that there are {} models match the kwargs".format(n_params))
            return network
        except Exception as e:
            self._log_exception(e)
            return False
        finally:
            if os.path.isdir(_temp_file_name):
                del_folder(_temp_file_name)

    def delete_model(self, **kwargs):
        """Delete models.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all models of this project.

        """
        self._fill_project_info(kwargs)
        self.db.Model.delete_many(kwargs)
        logging.info("[Database] Delete Model SUCCESS")

    # =========================== DATASET ===============================
    def save_dataset(self, dataset=None, dataset_name=None, **kwargs):
        """Save one dataset into the database; a timestamp is added automatically.

        Parameters
        ----------
        dataset : any type
            The dataset you want to store (must be picklable).
        dataset_name : str
            The name of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ----------
        Save dataset

        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset

        >>> dataset = db.find_top_dataset('mnist')

        Returns
        ---------
        boolean : Return True if save success, otherwise, return False.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        try:
            dataset_id = self.dataset_fs.put(self._serialization(dataset))
            kwargs.update({'dataset_id': dataset_id, 'time': datetime.utcnow()})
            self.db.Dataset.insert_one(kwargs)
            print("[Database] Save dataset: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save dataset: FAIL")
            return False

    def find_top_dataset(self, dataset_name=None, sort=None, **kwargs):
        """Find and return the dataset that best matches the requirement.

        Parameters
        ----------
        dataset_name : str
            The name of dataset.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ---------
        Save dataset

        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset

        >>> dataset = db.find_top_dataset('mnist')
        >>> datasets = db.find_datasets('mnist')

        Returns
        --------
        dataset : the dataset or False
            Return False if nothing is found or loading fails.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()

        d = self.db.Dataset.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find dataset: {}".format(kwargs))
            return False
        dataset_id = d['dataset_id']

        try:
            dataset = self._deserialization(self.dataset_fs.get(dataset_id).read())
            pc = self.db.Dataset.find(kwargs)
            print("[Database] Find one dataset SUCCESS, {} took: {}s".format(kwargs, round(time.time() - s, 2)))

            # check whether more datasets match the requirement
            n_dataset = len(pc.distinct('dataset_id'))
            if n_dataset != 1:
                print(" Note that there are {} datasets match the requirement".format(n_dataset))
            return dataset
        except Exception as e:
            self._log_exception(e)
            return False

    def find_datasets(self, dataset_name=None, **kwargs):
        """Find and return all datasets that match the requirement.

        In some cases the data of one dataset is stored in several parts for better
        management; this returns every part.

        Parameters
        ----------
        dataset_name : str
            The name/key of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Returns
        --------
        list or False
            The deserialized datasets, one per distinct ``dataset_id``; False if nothing matches.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        # find() always returns a cursor (never None), so emptiness must be checked on the ids.
        dataset_id_list = self.db.Dataset.find(kwargs).distinct('dataset_id')
        if not dataset_id_list:
            print("[Database] FAIL! Cannot find any dataset: {}".format(kwargs))
            return False

        # you may have multiple Buckets files
        dataset_list = [
            self._deserialization(self.dataset_fs.get(dataset_id).read()) for dataset_id in dataset_id_list
        ]
        print("[Database] Find {} datasets SUCCESS, took: {}s".format(len(dataset_list), round(time.time() - s, 2)))
        return dataset_list

    def delete_datasets(self, **kwargs):
        """Delete datasets.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all datasets of this project.

        """
        self._fill_project_info(kwargs)
        self.db.Dataset.delete_many(kwargs)
        logging.info("[Database] Delete Dataset SUCCESS")

    # =========================== LOGGING ===============================
    def save_training_log(self, **kwargs):
        """Save a training log entry; a timestamp is added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_training_log(accuracy=0.33, loss=0.98)

        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.TrainLog.insert_one(kwargs)
        _log = self._print_dict(kwargs)
        logging.info("[Database] train log: " + _log)

    def save_validation_log(self, **kwargs):
        """Save a validation log entry; a timestamp is added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_validation_log(accuracy=0.33, loss=0.98)

        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.ValidLog.insert_one(kwargs)
        _log = self._print_dict(kwargs)
        logging.info("[Database] valid log: " + _log)

    def save_testing_log(self, **kwargs):
        """Save a testing log entry; a timestamp is added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_testing_log(accuracy=0.33, loss=0.98)

        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.TestLog.insert_one(kwargs)
        _log = self._print_dict(kwargs)
        logging.info("[Database] test log: " + _log)

    def delete_training_log(self, **kwargs):
        """Delete training logs.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all training logs of this project.

        Examples
        ---------
        Save training log

        >>> db.save_training_log(accuracy=0.33)
        >>> db.save_training_log(accuracy=0.44)

        Delete logs that match the requirement

        >>> db.delete_training_log(accuracy=0.33)

        Delete all logs

        >>> db.delete_training_log()

        """
        self._fill_project_info(kwargs)
        self.db.TrainLog.delete_many(kwargs)
        logging.info("[Database] Delete TrainLog SUCCESS")

    def delete_validation_log(self, **kwargs):
        """Delete validation logs.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all validation logs of this project.

        Examples
        ---------
        - see ``delete_training_log``.

        """
        self._fill_project_info(kwargs)
        self.db.ValidLog.delete_many(kwargs)
        logging.info("[Database] Delete ValidLog SUCCESS")

    def delete_testing_log(self, **kwargs):
        """Delete testing logs.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all testing logs of this project.

        Examples
        ---------
        - see ``delete_training_log``.

        """
        self._fill_project_info(kwargs)
        self.db.TestLog.delete_many(kwargs)
        logging.info("[Database] Delete TestLog SUCCESS")

    # =========================== Task ===================================
    def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_result_keys=None, **kwargs):
        """Upload a task to the database; a timestamp is added automatically.

        The task is stored with status ``'pending'``, together with its ``task_name``,
        the raw bytes of ``script``, the hyper-parameters and an empty result.

        Parameters
        -----------
        task_name : str
            The task name; ``run_top_task`` and ``check_unfinished_task`` filter on it.
        script : str
            File name of the python script.
        hyper_parameters : dictionary
            The hyper parameters passed into the script.
        saved_result_keys : list of str
            The keys of the task results to keep in the database when the task finishes.
        kwargs : other parameters
            User customized parameters such as description, version number.

        Examples
        -----------
        Uploads a task

        >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')

        Finds and runs the latest task

        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.DESCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", -1)])

        Finds and runs the oldest task

        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.ASCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", 1)])

        Raises
        ------
        Exception
            If ``task_name`` or ``script`` is not a string.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        if not isinstance(script, str):
            raise Exception("script should be string")
        if hyper_parameters is None:
            hyper_parameters = {}
        if saved_result_keys is None:
            saved_result_keys = []

        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        # Store the name so workers can pick tasks by it.
        kwargs.update({'task_name': task_name})
        kwargs.update({'hyper_parameters': hyper_parameters})
        kwargs.update({'saved_result_keys': saved_result_keys})

        with open(script, 'rb') as f:
            _script = f.read()

        kwargs.update({'status': 'pending', 'script': _script, 'result': {}})
        self.db.Task.insert_one(kwargs)
        logging.info("[Database] Saved Task - task_name: {} script: {}".format(task_name, script))

    def run_top_task(self, task_name=None, sort=None, **kwargs):
        """Find the first pending task of the sorted list with this ``task_name`` and run it.

        The task is atomically switched from ``'pending'`` to ``'running'`` when picked.
        It becomes ``'finished'`` after the script ran, or goes back to ``'pending'`` on failure.

        Parameters
        -----------
        task_name : str
            The task name.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other parameters
            User customized parameters such as description, version number.

        Examples
        ---------
        Monitors the database and pull tasks to run

        >>> while True:
        >>>     print("waiting task from distributor")
        >>>     db.run_top_task(task_name='mnist', sort=[("time", -1)])
        >>>     time.sleep(1)

        Returns
        --------
        boolean : True if a task was found and ran successfully, False otherwise.

        Raises
        ------
        Exception
            If ``task_name`` is not a string.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)
        kwargs.update({'task_name': task_name, 'status': 'pending'})

        # find task and set status to running
        task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)
        if task is None:
            logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
            return False
        logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))

        # Read _id before the try so the failure handler can always reset the status.
        _id = task['_id']
        try:
            # get task info e.g. hyper parameters, python script
            _datetime = task['time']
            _script = task['script']
            _hyper_parameters = task['hyper_parameters']
            _saved_result_keys = task['saved_result_keys']

            # NOTE(review): hyper-parameters become module globals and the script is run with
            # exec() in this module's namespace; only run tasks from a trusted database.
            logging.info(" hyper parameters:")
            for key in _hyper_parameters:
                globals()[key] = _hyper_parameters[key]
                logging.info(" {}: {}".format(key, _hyper_parameters[key]))

            # run task
            s = time.time()
            logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
            _script = _script.decode('utf-8')
            with tf.Graph().as_default():  # run the script in a fresh default TF graph
                exec(_script, globals())

            # set status to finished
            self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

            # collect the requested results from the globals the script left behind
            __result = {}
            for _key in _saved_result_keys:
                logging.info(" result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
                __result.update({"%s" % _key: globals()[_key]})
            self.db.Task.find_one_and_update(
                {'_id': _id}, {'$set': {'result': __result}}, return_document=pymongo.ReturnDocument.AFTER
            )
            logging.info(
                "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
                    task_name, sort, _datetime, time.time() - s
                )
            )
            return True
        except Exception as e:
            self._log_exception(e)
            logging.info("[Database] Fail to run task")
            # if fail, set status back to pending
            self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'pending'}})
            return False

    def delete_tasks(self, **kwargs):
        """Delete tasks.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete; leave it empty to delete all tasks of this project.

        Examples
        ---------
        >>> db.delete_tasks()

        """
        self._fill_project_info(kwargs)
        self.db.Task.delete_many(kwargs)
        logging.info("[Database] Delete Task SUCCESS")

    def check_unfinished_task(self, task_name=None, **kwargs):
        """Check whether any task with this ``task_name`` is still pending or running.

        Parameters
        -----------
        task_name : str
            The task name.
        kwargs : other parameters
            User customized parameters such as description, version number.

        Examples
        ---------
        Wait until all tasks finish in user's local console

        >>> while db.check_unfinished_task(task_name='mnist'):
        >>>     time.sleep(1)
        >>> print("all tasks finished")
        >>> sess = tf.InteractiveSession()
        >>> net = db.find_top_model(sess=sess, sort=[("test_accuracy", -1)])
        >>> print("the best accuracy {} is from model {}".format(net._test_accuracy, net._name))

        Returns
        --------
        boolean : True if at least one matching task is pending or running, False otherwise.

        Raises
        ------
        Exception
            If ``task_name`` is not a string.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)
        kwargs.update({'task_name': task_name, '$or': [{'status': 'pending'}, {'status': 'running'}]})

        n_task = len(self.db.Task.find(kwargs).distinct('_id'))

        if n_task == 0:
            logging.info("[Database] No unfinished task - task_name: {}".format(task_name))
            return False
        logging.info("[Database] Find {} unfinished task - task_name: {}".format(n_task, task_name))
        return True

    @staticmethod
    def _print_dict(args):
        """Format ``args`` as ``'key: value / '`` pairs, skipping MongoDB's ``_id`` field."""
        # Compare by value: `is not '_id'` depended on string interning.
        return ''.join(str(key) + ": " + str(value) + " / " for key, value in args.items() if key != '_id')