Source code for tensorlayer.db

#! /usr/bin/python
# -*- coding: utf-8 -*-

import os
import pickle
import sys
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

import gridfs
import pymongo
from tensorlayer import logging
from tensorlayer.files import (
    assign_weights, del_folder, exists_or_mkdir, load_hdf5_to_weights, save_weights_to_hdf5, static_graph2net
)


class TensorHub(object):
    """MongoDB based manager for datasets, model architectures, parameters, logs and tasks.

    Parameters
    -------------
    ip : str
        Localhost or IP address.
    port : int
        Port number.
    dbname : str
        Database name.
    username : str or None
        User name. ``None`` (or the historical default string ``'None'``)
        means that no authentication is attempted.
    password : str
        Password, only used when a user name is given.
    project_name : str or None
        Experiment key for this entire project, similar to the repository name on GitHub.

    Attributes
    ------------
    ip, port, dbname, username : see above
        The given input parameters.
    project_name : str
        The given project name; if none is given, it is derived from ``sys.argv[0]``.
    db : mongodb database
        ``pymongo.MongoClient(...)[dbname]``.
    dataset_fs, model_fs : gridfs.GridFS
        GridFS buckets used to store serialized datasets and model parameters.

    """

    # @deprecated_alias(db_name='dbname', user_name='username', end_support_version=2.1)
    def __init__(
            self, ip='localhost', port=27017, dbname='dbname', username='None', password='password',
            project_name=None
    ):
        self.ip = ip
        self.port = port
        self.dbname = dbname
        self.username = username

        print("[Database] Initializing ...")
        # Connect to MongoDB. The default of ``username`` is the *string* 'None'
        # (kept for backward compatibility), so both spellings mean "no login".
        if username is None or username == 'None':
            client = pymongo.MongoClient(ip, port)
            print("[Database] No username given, it works if authentication is not required")
        else:
            # Credentials are handed to the client instead of the removed
            # ``Database.authenticate``; ``authSource=dbname`` authenticates
            # against the selected database. The password is never printed.
            # MongoClient connects lazily, so bad credentials surface on the
            # first database operation.
            client = pymongo.MongoClient(ip, port, username=username, password=password, authSource=dbname)
        self.db = client[dbname]

        if project_name is None:
            # splitext only strips the extension, so './train.py' -> './train'
            # (a plain split('.') would return an empty string here).
            self.project_name = os.path.splitext(sys.argv[0])[0]
            print("[Database] No project_name given, use {}".format(self.project_name))
        else:
            self.project_name = project_name

        # define file system (Buckets)
        self.dataset_fs = gridfs.GridFS(self.db, collection="datasetFilesystem")
        self.model_fs = gridfs.GridFS(self.db, collection="modelfs")
        # self.params_fs = gridfs.GridFS(self.db, collection="parametersFilesystem")
        # self.architecture_fs = gridfs.GridFS(self.db, collection="architectureFilesystem")

        print("[Database] Connected ")
        _s = "[Database] Info:\n"
        _s += " ip : {}\n".format(self.ip)
        _s += " port : {}\n".format(self.port)
        _s += " dbname : {}\n".format(self.dbname)
        _s += " username : {}\n".format(self.username)
        _s += " password : {}\n".format("*******")
        _s += " project_name : {}\n".format(self.project_name)
        self._s = _s
        print(self._s)

    def __str__(self):
        """Return the connection summary built in ``__init__``."""
        return self._s

    def _fill_project_info(self, args):
        """Add ``project_name`` to ``args`` in place and return the same dict."""
        args.update({'project_name': self.project_name})
        return args

    @staticmethod
    def _serialization(ps):
        """Serialize data with pickle (highest protocol)."""
        return pickle.dumps(ps, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def _deserialization(ps):
        """Deserialize data produced by ``_serialization``."""
        # NOTE(review): pickle.loads executes arbitrary code embedded in the
        # payload; only use this with a database whose content is trusted.
        return pickle.loads(ps)

    @staticmethod
    def _log_exception(e):
        """Log the exception currently being handled (type, value, file, line).

        Must be called from inside an ``except`` block.
        """
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))

    def _save_log(self, collection, label, record):
        """Stamp ``record`` with project name and UTC time, insert it and log it."""
        self._fill_project_info(record)
        record.update({'time': datetime.utcnow()})
        collection.insert_one(record)
        logging.info("[Database] {} log: ".format(label) + self._print_dict(record))

    # =========================== MODELS ================================
    def save_model(self, network=None, model_name='model', **kwargs):
        """Save model architecture and parameters into database, timestamp will be added automatically.

        Parameters
        ----------
        network : TensorLayer Model
            TensorLayer Model instance; its ``all_weights`` and ``config`` are stored.
        model_name : str
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        Save model architecture and parameters into database.

        >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')

        Load one model with parameters from database (run this in other script)

        >>> net = db.find_top_model(accuracy=0.8, loss=2.3)

        Find and load the latest model.

        >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
        >>> net = db.find_top_model(sort=[("time", -1)])

        Find and load the oldest model.

        >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
        >>> net = db.find_top_model(sort=[("time", 1)])

        Returns
        ---------
        boolean : True for success, False for fail.

        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)  # put project_name into kwargs

        params = network.all_weights

        s = time.time()

        kwargs.update({'architecture': network.config, 'time': datetime.utcnow()})

        try:
            params_id = self.model_fs.put(self._serialization(params))
            kwargs.update({'params_id': params_id, 'time': datetime.utcnow()})
            self.db.Model.insert_one(kwargs)
            print("[Database] Save model: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save model: FAIL")
            return False

    def find_top_model(self, sort=None, model_name='model', **kwargs):
        """Find a model architecture and its parameters in the database and rebuild the network.

        Parameters
        ----------
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer Model or False
            The restored network, or False if no document matches or restoring fails.
            Note that the extra fields of the stored document (e.g. accuracy) are
            currently NOT attached to the returned network; the code that did so
            is disabled (see the FIXME below).

        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)

        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        try:
            params = self._deserialization(self.model_fs.get(params_id).read())
            # TODO : restore model and load weights
            network = static_graph2net(graphs)
            assign_weights(weights=params, network=network)

            pc = self.db.Model.find(kwargs)
            print(
                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                    kwargs, sort, _datetime, round(time.time() - s, 2)
                )
            )

            # FIXME : not sure what's this for
            # put all informations of model into the TL layer
            # for key in d:
            #     network.__dict__.update({"_%s" % key: d[key]})

            # check whether more parameters match the requirement
            params_id_list = pc.distinct('params_id')
            n_params = len(params_id_list)
            if n_params != 1:
                print(" Note that there are {} models match the kwargs".format(n_params))
            return network
        except Exception as e:
            self._log_exception(e)
            return False

    def delete_model(self, **kwargs):
        """Delete models.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all models of this project.

        """
        self._fill_project_info(kwargs)
        self.db.Model.delete_many(kwargs)
        logging.info("[Database] Delete Model SUCCESS")

    # =========================== DATASET ===============================
    def save_dataset(self, dataset=None, dataset_name=None, **kwargs):
        """Saves one dataset into database, timestamp will be added automatically.

        Parameters
        ----------
        dataset : any type
            The dataset you want to store (it is pickled).
        dataset_name : str
            The name of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ----------
        Save dataset

        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset

        >>> dataset = db.find_top_dataset('mnist')

        Returns
        ---------
        boolean : Return True if save success, otherwise, return False.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        try:
            dataset_id = self.dataset_fs.put(self._serialization(dataset))
            kwargs.update({'dataset_id': dataset_id, 'time': datetime.utcnow()})
            self.db.Dataset.insert_one(kwargs)
            print("[Database] Save dataset: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save dataset: FAIL")
            return False

    def find_top_dataset(self, dataset_name=None, sort=None, **kwargs):
        """Finds and returns a dataset from the database which matches the requirement.

        Parameters
        ----------
        dataset_name : str
            The name of dataset.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ---------
        Save dataset

        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset

        >>> dataset = db.find_top_dataset('mnist')
        >>> datasets = db.find_datasets('mnist')

        Returns
        --------
        dataset : the dataset or False
            Return False if nothing found or if loading fails.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()

        d = self.db.Dataset.find_one(filter=kwargs, sort=sort)

        if d is None:
            print("[Database] FAIL! Cannot find dataset: {}".format(kwargs))
            return False

        dataset_id = d['dataset_id']
        try:
            dataset = self._deserialization(self.dataset_fs.get(dataset_id).read())
            pc = self.db.Dataset.find(kwargs)
            print("[Database] Find one dataset SUCCESS, {} took: {}s".format(kwargs, round(time.time() - s, 2)))

            # check whether more datasets match the requirement
            dataset_id_list = pc.distinct('dataset_id')
            n_dataset = len(dataset_id_list)
            if n_dataset != 1:
                print(" Note that there are {} datasets match the requirement".format(n_dataset))
            return dataset
        except Exception as e:
            self._log_exception(e)
            return False

    def find_datasets(self, dataset_name=None, **kwargs):
        """Finds and returns all datasets from the database which matches the requirement.
        In some case, the data in a dataset can be stored separately for better management.

        Parameters
        ----------
        dataset_name : str
            The name/key of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Returns
        --------
        datasets : list of datasets or False
            One entry per distinct stored dataset; False if nothing found.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.

        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        pc = self.db.Dataset.find(kwargs)

        # ``find`` always returns a cursor (never None), so "nothing found"
        # has to be detected from the empty id list.
        dataset_id_list = pc.distinct('dataset_id')
        if not dataset_id_list:
            print("[Database] FAIL! Cannot find any dataset: {}".format(kwargs))
            return False

        # you may have multiple Buckets files
        dataset_list = [
            self._deserialization(self.dataset_fs.get(dataset_id).read()) for dataset_id in dataset_id_list
        ]

        print("[Database] Find {} datasets SUCCESS, took: {}s".format(len(dataset_list), round(time.time() - s, 2)))
        return dataset_list

    def delete_datasets(self, **kwargs):
        """Delete datasets.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all datasets of this project.

        """
        self._fill_project_info(kwargs)
        self.db.Dataset.delete_many(kwargs)
        logging.info("[Database] Delete Dataset SUCCESS")

    # =========================== LOGGING ===============================
    def save_training_log(self, **kwargs):
        """Saves the training log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_training_log(accuracy=0.33, loss=0.98)

        """
        self._save_log(self.db.TrainLog, "train", kwargs)

    def save_validation_log(self, **kwargs):
        """Saves the validation log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_validation_log(accuracy=0.33, loss=0.98)

        """
        self._save_log(self.db.ValidLog, "valid", kwargs)

    def save_testing_log(self, **kwargs):
        """Saves the testing log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_testing_log(accuracy=0.33, loss=0.98)

        """
        self._save_log(self.db.TestLog, "test", kwargs)

    def delete_training_log(self, **kwargs):
        """Deletes training log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        Save training log

        >>> db.save_training_log(accuracy=0.33)
        >>> db.save_training_log(accuracy=0.44)

        Delete logs that match the requirement

        >>> db.delete_training_log(accuracy=0.33)

        Delete all logs

        >>> db.delete_training_log()

        """
        self._fill_project_info(kwargs)
        self.db.TrainLog.delete_many(kwargs)
        logging.info("[Database] Delete TrainLog SUCCESS")

    def delete_validation_log(self, **kwargs):
        """Deletes validation log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        - see ``save_training_log``.

        """
        self._fill_project_info(kwargs)
        self.db.ValidLog.delete_many(kwargs)
        logging.info("[Database] Delete ValidLog SUCCESS")

    def delete_testing_log(self, **kwargs):
        """Deletes testing log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        - see ``save_training_log``.

        """
        self._fill_project_info(kwargs)
        self.db.TestLog.delete_many(kwargs)
        logging.info("[Database] Delete TestLog SUCCESS")

    # def find_training_logs(self, **kwargs):
    #     pass
    #
    # def find_validation_logs(self, **kwargs):
    #     pass
    #
    # def find_testing_logs(self, **kwargs):
    #     pass

    # =========================== Task ===================================
    def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_result_keys=None, **kwargs):
        """Uploads a task to the database, timestamp will be added automatically.

        Parameters
        -----------
        task_name : str
            The task name; stored with the task so that workers can select it.
        script : str
            File name of the python script; its content is read and stored.
        hyper_parameters : dictionary
            The hyper parameters pass into the script.
        saved_result_keys : list of str
            The keys of the task results to keep in the database when the task finishes.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        -----------
        Uploads a task

        >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')

        Finds and runs the latest task

        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.DESCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", -1)])

        Finds and runs the oldest task

        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.ASCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", 1)])

        Raises
        ------
        Exception
            If ``task_name`` or ``script`` is not a string.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        if not isinstance(script, str):
            raise Exception("script should be string")
        if hyper_parameters is None:
            hyper_parameters = {}
        if saved_result_keys is None:
            saved_result_keys = []

        self._fill_project_info(kwargs)
        kwargs.update({'task_name': task_name, 'time': datetime.utcnow()})
        kwargs.update({'hyper_parameters': hyper_parameters})
        kwargs.update({'saved_result_keys': saved_result_keys})

        # Read the script as bytes; the context manager closes the file handle.
        with open(script, 'rb') as script_file:
            _script = script_file.read()

        kwargs.update({'status': 'pending', 'script': _script, 'result': {}})
        self.db.Task.insert_one(kwargs)
        logging.info("[Database] Saved Task - task_name: {} script: {}".format(task_name, script))

    def run_top_task(self, task_name=None, sort=None, **kwargs):
        """Finds and runs the first pending task named ``task_name`` in the sorting order.

        Parameters
        -----------
        task_name : str
            The task name; only tasks stored under this name are considered
            (tasks stored without a ``task_name`` field are not matched).
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Monitors the database and pull tasks to run

        >>> while True:
        >>>     print("waiting task from distributor")
        >>>     db.run_top_task(task_name='mnist', sort=[("time", -1)])
        >>>     time.sleep(1)

        Returns
        --------
        boolean : True if a task was found and run, False if no pending task matches.

        Raises
        ------
        Exception
            If ``task_name`` is not a string. Any exception raised by the task
            script itself also propagates to the caller.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)
        kwargs.update({'task_name': task_name, 'status': 'pending'})

        # find task and set status to running
        task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)

        # get task info e.g. hyper parameters, python script
        if task is None:
            logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
            return False

        logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
        _datetime = task['time']
        _script = task['script']
        _id = task['_id']
        _hyper_parameters = task['hyper_parameters']
        _saved_result_keys = task['saved_result_keys']

        # Hyper parameters are injected as module globals so that the script,
        # which is executed in the same global namespace, can read them.
        logging.info(" hyper parameters:")
        for key in _hyper_parameters:
            globals()[key] = _hyper_parameters[key]
            logging.info(" {}: {}".format(key, _hyper_parameters[key]))

        # run task
        s = time.time()
        logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
        _script = _script.decode('utf-8')
        # NOTE(review): this executes code fetched from the database with full
        # privileges; only use it with a trusted database.
        # NOTE(review): if the script raises, the exception propagates and the
        # task stays in status 'running'.
        with tf.Graph().as_default():  # run the script inside a fresh default TF graph
            exec(_script, globals())

        # set status to finished
        self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

        # collect the requested results from the globals the script populated
        __result = {}
        for _key in _saved_result_keys:
            logging.info(" result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
            __result.update({"%s" % _key: globals()[_key]})
        self.db.Task.find_one_and_update(
            {'_id': _id}, {'$set': {
                'result': __result
            }}, return_document=pymongo.ReturnDocument.AFTER
        )
        logging.info(
            "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
                task_name, sort, _datetime,
                time.time() - s
            )
        )
        return True

    def delete_tasks(self, **kwargs):
        """Delete tasks.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all tasks of this project.

        Examples
        ---------
        >>> db.delete_tasks()

        """
        self._fill_project_info(kwargs)
        self.db.Task.delete_many(kwargs)
        logging.info("[Database] Delete Task SUCCESS")

    def check_unfinished_task(self, task_name=None, **kwargs):
        """Checks whether tasks named ``task_name`` are still pending or running.

        Parameters
        -----------
        task_name : str
            The task name.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Wait until all tasks finish in user's local console

        >>> while db.check_unfinished_task(task_name='mnist'):
        >>>     time.sleep(1)
        >>> print("all tasks finished")
        >>> net = db.find_top_model(sort=[("test_accuracy", -1)])

        Returns
        --------
        boolean : True if at least one matching task is pending or running, False otherwise.

        Raises
        ------
        Exception
            If ``task_name`` is not a string.

        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)

        kwargs.update({'task_name': task_name, '$or': [{'status': 'pending'}, {'status': 'running'}]})

        task = self.db.Task.find(kwargs)

        task_id_list = task.distinct('_id')
        n_task = len(task_id_list)

        if n_task == 0:
            logging.info("[Database] No unfinished task - task_name: {}".format(task_name))
            return False

        logging.info("[Database] Find {} unfinished task - task_name: {}".format(n_task, task_name))
        return True

    @staticmethod
    def _print_dict(args):
        """Format a record as ``"key: value / key: value / "``, skipping ``_id``."""
        # Compare by value: identity checks against a string literal are unreliable.
        return ''.join("{}: {} / ".format(key, value) for key, value in args.items() if key != '_id')