#! /usr/bin/python
# -*- coding: utf-8 -*-
import functools
import inspect
import pickle
import time
import uuid
from datetime import datetime

try:
    import gridfs
    from pymongo import MongoClient
except ImportError:
    install_instr = "Please make sure you install PyMongo with the command: pip install pymongo."
    raise ImportError("__init__.py : Could not import PyMongo." + install_instr)
def AutoFill(func):
    """Decorator that injects the instance's ``studyID`` into the ``args`` dict.

    The decorated method must accept a parameter named ``args`` (a dictionary
    used as a MongoDB filter or document) and must be bound to an object that
    exposes a ``studyID`` attribute.

    Parameters
    ----------
    func : callable
        The method to wrap. Its first parameter is the instance (``self``).

    Returns
    --------
    callable
        A wrapper that resolves the call arguments by name, makes sure ``args``
        is a dictionary, adds ``{"studyID": self.studyID}`` to it and then
        calls ``func`` with every argument passed by keyword.

    Notes
    -----
    - When the caller passes a dictionary, it is updated in place (the caller
      will see the added ``studyID`` key).
    - When the caller omits ``args`` (or passes None), a fresh dictionary is
      created so that ``studyID`` is still applied.
    """

    # functools.wraps keeps the wrapped method's __name__ and __doc__ so that
    # help() and documentation tools still see the real method.
    @functools.wraps(func)
    def func_wrapper(self, *args, **kwargs):
        # Map positional/keyword arguments to parameter names of ``func``.
        call_args = inspect.getcallargs(func, self, *args, **kwargs)
        # ``args`` defaults to None in the decorated methods; calling
        # ``None.update`` would raise, so substitute an empty dict first.
        if call_args['args'] is None:
            call_args['args'] = {}
        call_args['args'].update({"studyID": self.studyID})
        return func(**call_args)

    return func_wrapper
class TensorDB(object):
    """TensorDB is a MongoDB based manager that helps you to manage data, network topology, parameters and logging.

    Parameters
    -------------
    ip : str
        Localhost or IP address.
    port : int
        Port number.
    db_name : str
        Database name.
    user_name : str
        User name. Set to None if the server does not need authentication.
    password : str
        Password, only used when ``user_name`` is not None.
    studyID : str
        Unique ID of the study. If None, a random one is generated.

    Attributes
    ------------
    db : ``pymongo.MongoClient(...)[db_name]``, the database handle.
    datafs : ``gridfs.GridFS(self.db, collection="datafs")``
    modelfs : ``gridfs.GridFS(self.db, collection="modelfs")``
    paramsfs : ``gridfs.GridFS(self.db, collection="paramsfs")``, stores the pickled parameters.
    archfs : ``gridfs.GridFS(self.db, collection="ModelArchitecture")``, stores the model architectures.
    db.Params : Collection of parameter meta data, each item holds the bucket ID ``f_id``.
    db.TrainLog : Collection of the items saved by ``train_log``.
    db.ValidLog : Collection of the items saved by ``valid_log``.
    db.TestLog : Collection of the items saved by ``test_log``.
    db.march : Collection of model architecture meta data, each item holds the bucket ID ``fid``.
    db.Job : Collection of the items saved by ``save_job``.
    db.JOBS : Collection of the items saved by ``push_job``.
    studyID : string, unique ID, randomly generated if None was given.

    Notes
    -------------
    - MongoDB, as TensorDB is based on MongoDB, you need to install it in your local machine or remote machine.
    - pip install pymongo, for MongoDB python API.
    - You may like to install MongoChef or Mongo Management Studio APP for visualizing or testing your MongoDB.
    """

    def __init__(
            self, ip='localhost', port=27017, db_name='db_name', user_name=None, password='password', studyID=None
    ):
        ## connect mongodb
        # ``Database.authenticate`` no longer exists in PyMongo 4, so the
        # credentials are handed to the client and ``db_name`` is used as the
        # authentication database (what ``self.db.authenticate`` did before).
        if user_name is not None:
            client = MongoClient(ip, port, username=user_name, password=password, authSource=db_name)
        else:
            client = MongoClient(ip, port)
        self.db = client[db_name]

        if studyID is None:
            self.studyID = str(uuid.uuid1())
        else:
            self.studyID = studyID

        ## define file system (Buckets)
        self.datafs = gridfs.GridFS(self.db, collection="datafs")
        self.modelfs = gridfs.GridFS(self.db, collection="modelfs")
        self.paramsfs = gridfs.GridFS(self.db, collection="paramsfs")
        self.archfs = gridfs.GridFS(self.db, collection="ModelArchitecture")

        # Report the ID actually in use (the generated one when None was given).
        print("[TensorDB] Connect SUCCESS {}:{} {} {} {}".format(ip, port, db_name, user_name, self.studyID))

        self.ip = ip
        self.port = port
        self.db_name = db_name
        self.user_name = user_name

    def __autofill(self, args):
        """Add this instance's ``studyID`` to ``args`` in place and return ``args``."""
        # Must be an instance method: ``studyID`` is set per instance in __init__.
        args.update({'studyID': self.studyID})
        return args

    @staticmethod
    def __serialization(ps):
        """Pickle ``ps`` with protocol 2."""
        return pickle.dumps(ps, protocol=2)

    @staticmethod
    def __deserialization(ps):
        """Unpickle the bytes ``ps``."""
        # NOTE(security): pickle.loads executes arbitrary code if the stored
        # bytes are malicious; only use this with a database you trust.
        return pickle.loads(ps)

    def save_params(self, params=None, args=None):
        """ Save parameters into MongoDB Buckets, and save the file ID into Params Collections.

        Parameters
        ----------
        params : a list of parameters
        args : dictionary, item meta data. It is updated in place with ``studyID``, ``f_id`` and ``time``.

        Returns
        ---------
        f_id : the Buckets ID of the parameters.
        """
        if params is None:
            params = []
        if args is None:
            args = {}
        self.__autofill(args)
        s = time.time()
        f_id = self.paramsfs.put(self.__serialization(params))
        args.update({'f_id': f_id, 'time': datetime.utcnow()})
        self.db.Params.insert_one(args)
        print("[TensorDB] Save params: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
        return f_id

    @AutoFill
    def find_one_params(self, args=None, sort=None):
        """ Find one parameter from MongoDB Buckets.

        Parameters
        ----------
        args : dictionary
            For finding items.
        sort : passed as the ``sort`` argument of ``find_one``.

        Returns
        --------
        params : the parameters, return False if nothing found.
        f_id : the Buckets ID of the parameters, return False if nothing found.
        """
        if args is None:
            args = {}
        s = time.time()
        d = self.db.Params.find_one(filter=args, sort=sort)
        if d is None:
            print("[TensorDB] FAIL! Cannot find: {}".format(args))
            return False, False

        f_id = d['f_id']
        try:
            params = self.__deserialization(self.paramsfs.get(f_id).read())
        except Exception as e:
            # Keep the best-effort (False, False) contract, but say why.
            print("[TensorDB] FAIL! Cannot load params {}: {}".format(f_id, e))
            return False, False
        print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time() - s, 2)))
        return params, f_id

    @AutoFill
    def find_all_params(self, args=None):
        """ Find all parameter from MongoDB Buckets

        Parameters
        ----------
        args : dictionary, find items

        Returns
        --------
        params : a list of the parameters, return False if nothing found.
        """
        if args is None:
            args = {}
        s = time.time()
        # ``find`` always returns a cursor, so emptiness has to be checked on
        # the resulting ID list, not on the cursor itself.
        f_id_list = self.db.Params.find(args).distinct('f_id')
        if not f_id_list:
            print("[TensorDB] FAIL! Cannot find any: {}".format(args))
            return False

        # you may have multiple Buckets files
        params = [self.__deserialization(self.paramsfs.get(f_id).read()) for f_id in f_id_list]
        print("[TensorDB] Find all params SUCCESS, took: {}s".format(round(time.time() - s, 2)))
        return params

    @AutoFill
    def del_params(self, args=None):
        """ Delete params in MongoDB Buckets.

        Parameters
        -----------
        args : dictionary, find items to delete, leave it empty to delete all parameters of this study.
        """
        if args is None:
            args = {}
        f_id_list = self.db.Params.find(args).distinct('f_id')
        # remove from Buckets
        for f in f_id_list:
            self.paramsfs.delete(f)
        # remove from Collections (``remove`` is gone in PyMongo 4)
        self.db.Params.delete_many(args)
        print("[TensorDB] Delete params SUCCESS: {}".format(args))

    @staticmethod
    def _print_dict(args):
        """Format ``args`` as ``"key: value / "`` pairs, skipping the ``_id`` key."""
        return ''.join(
            str(key) + ": " + str(value) + " / " for key, value in args.items() if key != '_id'
        )

    ## =========================== LOG =================================== ##
    @AutoFill
    def train_log(self, args=None):
        """Save the training log.

        Parameters
        -----------
        args : dictionary, items to save.

        Returns
        --------
        The result of ``insert_one``.

        Examples
        ---------
        >>> db.train_log({'loss': loss, 'acc': acc, 'time': time.time()})
        """
        if args is None:
            args = {}
        # Unlike valid_log / test_log, the training log is not printed.
        return self.db.TrainLog.insert_one(args)

    @AutoFill
    def del_train_log(self, args=None):
        """ Delete train log.

        Parameters
        -----------
        args : dictionary, find items to delete, leave it empty to delete all log of this study.
        """
        if args is None:
            args = {}
        self.db.TrainLog.delete_many(args)
        print("[TensorDB] Delete TrainLog SUCCESS")

    @AutoFill
    def valid_log(self, args=None):
        """Save the validating log.

        Parameters
        -----------
        args : dictionary, items to save.

        Returns
        --------
        The result of ``insert_one``.

        Examples
        ---------
        >>> db.valid_log({'loss': loss, 'acc': acc, 'time': time.time()})
        """
        if args is None:
            args = {}
        _result = self.db.ValidLog.insert_one(args)
        _log = self._print_dict(args)
        print("[TensorDB] ValidLog: " + _log)
        return _result

    @AutoFill
    def del_valid_log(self, args=None):
        """ Delete validation log.

        Parameters
        -----------
        args : dictionary, find items to delete, leave it empty to delete all log of this study.
        """
        if args is None:
            args = {}
        self.db.ValidLog.delete_many(args)
        print("[TensorDB] Delete ValidLog SUCCESS")

    @AutoFill
    def test_log(self, args=None):
        """Save the testing log.

        Parameters
        -----------
        args : dictionary, items to save.

        Returns
        --------
        The result of ``insert_one``.

        Examples
        ---------
        >>> db.test_log({'loss': loss, 'acc': acc, 'time': time.time()})
        """
        if args is None:
            args = {}
        _result = self.db.TestLog.insert_one(args)
        _log = self._print_dict(args)
        print("[TensorDB] TestLog: " + _log)
        return _result

    @AutoFill
    def del_test_log(self, args=None):
        """ Delete test log.

        Parameters
        -----------
        args : dictionary, find items to delete, leave it empty to delete all log of this study.
        """
        if args is None:
            args = {}
        self.db.TestLog.delete_many(args)
        print("[TensorDB] Delete TestLog SUCCESS")

    # =========================== Network Architecture ================== ##
    @AutoFill
    def save_model_architecture(self, s, args=None):
        """Save a model architecture into the ``archfs`` Buckets and its meta data into ``db.march``.

        Parameters
        -----------
        s : the architecture content handed to ``archfs.put``.
        args : dictionary, item meta data. It is updated in place with ``fid``.
        """
        if args is None:
            args = {}
        # ``studyID`` is already in ``args`` here: the AutoFill decorator adds it.
        fid = self.archfs.put(s, filename="modelarchitecture")
        args.update({"fid": fid})
        self.db.march.insert_one(args)

    @AutoFill
    def load_model_architecture(self, args=None):
        """Find one model architecture.

        Parameters
        -----------
        args : dictionary, find items.

        Returns
        --------
        archs : the content read from the Buckets, return False if nothing found or reading failed.
        fid : the Buckets ID of the architecture, return False if nothing found or reading failed.
        """
        if args is None:
            args = {}
        d = self.db.march.find_one(args)
        if d is None:
            print("[TensorDB] FAIL! Cannot find: {}".format(args))
            print("no idtem")
            return False, False

        fid = d['fid']
        print(d)
        print(fid)
        try:
            archs = self.archfs.get(fid).read()
            return archs, fid
        except Exception as e:
            print("exception")
            print(e)
            return False, False

    @AutoFill
    def save_job(self, script=None, args=None):
        """Save the job.

        Parameters
        -----------
        script : a script file name or None.
        args : dictionary, items to save.

        Returns
        --------
        The result of ``replace_one``.

        Examples
        ---------
        >>> # Save your job
        >>> db.save_job('your_script.py', {'job_id': 1, 'learning_rate': 0.01, 'n_units': 100})
        >>> # Run your job
        >>> temp = db.find_one_job(args={'job_id': 1})
        >>> print(temp['learning_rate'])
        ... 0.01
        >>> import _your_script
        ... running your script
        """
        if args is None:
            args = {}
        # ``studyID`` is already in ``args`` here: the AutoFill decorator adds it.
        if script is not None:
            # Close the file deterministically instead of leaking the handle.
            with open(script, 'rb') as script_file:
                _script = script_file.read()
            args.update({'script': _script, 'script_name': script})
        # Upsert with the document as its own filter: saving the same job twice
        # does not create a duplicate.
        _result = self.db.Job.replace_one(args, args, upsert=True)
        print("[TensorDB] Save Job: script={}, args={}".format(script, args))
        return _result

    @AutoFill
    def find_one_job(self, args=None):
        """ Find one job from MongoDB Job Collections.

        If the job was saved with a script, the script is written to the file
        ``'_' + script_name`` (relative to the current working directory).

        Parameters
        ----------
        args : dictionary, find items.

        Returns
        --------
        dictionary : contains all meta data and script, return False if nothing found.
        """
        if args is None:
            args = {}
        temp = self.db.Job.find_one(args)
        if temp is None:
            print("[TensorDB] FAIL! Cannot find any: {}".format(args))
            return False

        if 'script_name' in temp:
            with open('_' + temp['script_name'], 'wb') as f:
                f.write(temp['script'])
        print("[TensorDB] Find Job: {}".format(args))
        return temp

    def push_job(self, margs, wargs, dargs, epoch):
        """Queue a job in ``db.JOBS`` that references a model architecture and a set of parameters.

        Parameters
        -----------
        margs : dictionary, find items for ``load_model_architecture``.
        wargs : dictionary, find items for ``find_one_params``.
        dargs : stored as is under the key ``dargs``.
        epoch : stored as is under the key ``epoch``.
        """
        # Only the bucket IDs are stored. NOTE: when a lookup fails the helper
        # returns (False, False), so False is stored as the ID.
        _ms, mid = self.load_model_architecture(margs)
        _weight, wid = self.find_one_params(wargs)
        args = {
            "weight": wid,
            "model": mid,
            "dargs": dargs,
            "epoch": epoch,
            "time": datetime.utcnow(),
            "Running": False
        }
        self.__autofill(args)
        self.db.JOBS.insert_one(args)

    def peek_job(self):
        """Fetch one job of this study whose ``Running`` flag is False.

        Returns
        --------
        False if no such job exists, otherwise the tuple
        ``(job _id, architecture content, unpickled parameters, dargs, epoch)``.
        """
        args = {'Running': False}
        self.__autofill(args)
        m = self.db.JOBS.find_one(args)
        print(m)
        if m is None:
            return False

        s = self.paramsfs.get(m['weight']).read()
        w = self.__deserialization(s)
        ach = self.archfs.get(m['model']).read()
        return m['_id'], ach, w, m["dargs"], m['epoch']

    def run_job(self, jid):
        """Mark the job ``jid`` as running and record the start time under ``Since``."""
        self.db.JOBS.find_one_and_update({'_id': jid}, {'$set': {'Running': True, "Since": datetime.utcnow()}})

    def del_job(self, jid):
        """Record the finish time of job ``jid`` under ``Finished``.

        The document is not removed; ``Running`` stays True so that
        ``peek_job`` does not return the job again.
        """
        self.db.JOBS.find_one_and_update({'_id': jid}, {'$set': {'Running': True, "Finished": datetime.utcnow()}})

    def __str__(self):
        """Return a short description that includes the database handle."""
        _s = "[TensorDB] Info:\n"
        _t = _s + " " + str(self.db)
        return _t