diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py
index 77fa70f235..e5dcd0ddfe 100755
--- a/gensim/models/ldamodel.py
+++ b/gensim/models/ldamodel.py
@@ -50,6 +50,11 @@ logger = logging.getLogger('gensim.models.ldamodel')
+# DD
+USE_DASK = True
+logging.basicConfig(
+    format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
+
 DTYPE_TO_EPS = {
     np.float16: 1e-5,
     np.float32: 1e-35,
@@ -182,6 +187,366 @@ def load(cls, fname, *args, **kwargs):
         return result
 # endclass LdaState
 
+# DD
+# Dask-based replacements for the Pyro4 dispatcher/worker pair.
+import pickle
+
+
+class DaskWorker(object):
+    # message types understood by the worker event loop
+    RESET = 'RESET'
+    EVAL = 'EVAL'
+    GETSTATE = 'GETSTATE'
+    DONE = 'DONE'
+
+    def __init__(self, state, inqueue, outqueue, myid, id2word, num_topics, chunksize,
+                 alpha, eta, distributed, random_state=None):
+        logger.info("DaskWorker init")
+        self.jobsdone = 0  # how many jobs has this worker completed?
+        self.myid = myid  # id of this worker in the dispatcher; a convenience var for logging
+        self.model = LdaModel(id2word=id2word, num_topics=num_topics, chunksize=chunksize,
+                              alpha=alpha, eta=eta, distributed=distributed, random_state=random_state)
+
+        # Enter the event loop: request jobs from the dispatcher, in a perpetual
+        # loop, until a DONE message arrives. Note that the constructor itself
+        # runs the loop, so the submitted task stays alive for the whole run.
+        from distributed import worker_client, Variable, Queue, secede
+        import tornado
+        with worker_client(separate_thread=False) as c:
+            secede()  # leave the worker's thread pool so we don't block other tasks
+            self.state = Variable(name=state, client=c)
+            self.inqueue = Queue(name=inqueue, client=c)
+            self.outqueue = Queue(name=outqueue, client=c)
+
+            last_msg_type = None
+            while True:
+                try:
+                    job = pickle.loads(self.inqueue.get(timeout=5))
+                    logger.info("worker #%s received job #%i", self.myid, self.jobsdone)
+                    self.processjob(job)
+                except tornado.util.TimeoutError:
+                    pass  # no job available; fall through and check the control Variable
+                msg = self.state.get()  # blocking
+                msg_type = msg[0]
+                logger.debug("worker #%s got control message %s", self.myid, msg_type)
+                if msg_type == DaskWorker.DONE:
+                    self.outqueue.put(None)
+                    return
+                elif msg_type == DaskWorker.EVAL:
+                    continue
+                elif msg_type == DaskWorker.RESET:
+                    if msg_type == last_msg_type:
+                        continue  # broadcast already handled
+                    self.reset(pickle.loads(msg[1]))
+                    self.outqueue.put(None)
+                elif msg_type == DaskWorker.GETSTATE:
+                    if msg_type == last_msg_type:
+                        continue  # broadcast already handled
+                    self.outqueue.put(pickle.dumps(self.getstate()))
+                else:
+                    raise ValueError("unknown message type %s" % msg_type)
+                last_msg_type = msg_type
+
+    def processjob(self, job):
+        logger.debug("starting to process job #%i", self.jobsdone)
+        if self.model is None:
+            raise RuntimeError("worker must be initialized before receiving jobs")
+        self.model.do_estep(job)
+        self.jobsdone += 1
+        logger.info("finished processing job #%i", self.jobsdone - 1)
+
+    def getstate(self):
+        logger.info("worker #%i returning its state after %s jobs",
+                    self.myid, self.jobsdone)
+        result = self.model.state
+        assert isinstance(result, LdaState)
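+        # the state is handed back by value (the event loop pickles it onto
+        # the output queue), so the model's own copy can be dropped safely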
+        self.model.clear()  # free up memory in-between two EM cycles
+        return result
+
+    def reset(self, state):
+        assert state is not None
+        logger.info("resetting worker #%i", self.myid)
+        self.model.state = state
+        self.model.sync_state()
+        self.model.state.reset()
+
+
+class DaskDispatcher(object):
+    """
+    Dispatcher object that communicates with and coordinates individual workers.
+
+    There should never be more than one dispatcher running at any one time.
+    """
+    MAX_JOBS_QUEUE = 10
+    HUGE_TIMEOUT = 365 * 24 * 60 * 60  # one year
+
+    def __init__(self, maxsize=MAX_JOBS_QUEUE, ns_conf=None):
+        """
+        Note that the constructor does not fully initialize the dispatcher;
+        use the `initialize()` function to populate it with workers etc.
+        """
+        self.maxsize = maxsize
+        self.ns_conf = ns_conf if ns_conf is not None else {}
+
+    def initialize(self, id2word, num_topics, chunksize, alpha, eta, distributed, random_state=None):
+        """
+        The parameters are used to initialize the individual workers (they get
+        handed all the way down to each `DaskWorker`).
+        """
+        self._jobsdone = 0
+        self._jobsreceived = 0
+
+        # connect to the scheduler and set up the shared communication channels
+        from distributed.security import Security
+        from dask.distributed import Client, Variable, Queue
+        self.dispatcher = Client(name="TME", security=Security())
+
+        self.workerin = Queue()
+        self.workerout = Queue()
+        self.state = Variable()
+
+        info = self.dispatcher.scheduler_info()
+        assert info['workers']
+        # submit one long-running DaskWorker task per dask worker
+        self.workers = {
+            worker_id: self.dispatcher.submit(
+                DaskWorker, state=self.state.name, inqueue=self.workerin.name,
+                outqueue=self.workerout.name, myid=worker_id, id2word=id2word,
+                num_topics=num_topics, chunksize=chunksize, alpha=alpha, eta=eta,
+                distributed=False, random_state=random_state, pure=False)
+            for worker_id, _ in enumerate(info['workers'])
+        }
+
+        if not self.workers:
+            raise RuntimeError(
+                'no workers found; run some dask-worker processes on your machines first!')
+        for w in self.workers.values():
+            assert w.status != 'error'
+
+    def getworkers(self):
+        """
+        Return the ids of all registered workers.
+        """
+        return list(self.workers.keys())
+
+    def putjob(self, job):
+        self._jobsreceived += 1
+        self.state.set([DaskWorker.EVAL])
+        self.workerin.put(pickle.dumps(job), timeout=DaskDispatcher.HUGE_TIMEOUT)
+        logger.info("added a new job (len(queue)=%i items)", self.workerin.qsize())
+
+    def getstate(self):
+        """
+        Merge states from across all workers and return the result.
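+
+        A GETSTATE message is broadcast through the shared Variable; every
+        worker replies by pushing its pickled LdaState onto the output queue,
+        and the replies are merged here in arrival order.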
+ """ + logger.info("end of input, assigning all remaining jobs") + logger.debug("jobs done: %s, jobs received: %s", + self._jobsdone, self._jobsreceived) + """ + while self._jobsdone < self._jobsreceived: + me.sleep(0.5) # check e/fouvery half a second + + logger.info("merging states from %i workers", len(self.workers)) + workers = list(self.workers.values()) + result = workers[0].getstate() + for worker in workers[1:]: + result.merge(worker.getstate()) + """ + self.state.set([DaskWorker.GETSTATE]) + assert self.workerin.qsize() == 0 + result = None + for _ in range(len(self.workers)): + state = self.workerout.get() + import pickle + state = pickle.loads(state) + if result is None: + result = state + else: + result.merge(state) + assert self.workerout.qsize() == 0 + self.state.set([DaskWorker.EVAL]) + logger.info("sending out merged state") + return result + + def reset(self, state): + """ Initialize all workers for a new EM iterations. """ + import pickle + pickled = pickle.dumps(state) + self.state.set([DaskWorker.RESET,pickled]) + assert self.workerin.qsize() == 0 + for _ in range(len(self.workers)): + workerstate = self.workerout.get() + assert workerstate is None + #logger.info("resetting worker %s", workerid) + #worker.reset(state) + #worker.requestjob() + assert self.workerout.qsize() == 0 + self.state.set([DaskWorker.EVAL]) + self._jobsdone = 0 + self._jobsreceived = 0 + def __del__(self): + """ + Initialize all workers for a new EM iterations. + """ + self.state.set([DaskWorker.DONE]) + assert self.workerin.qsize() == 0 + for _ in range(len(self.workers)): + state = self.workerout.get() + assert state is None + #logger.info("resetting worker %s", workerid) + #worker.reset(state) + #worker.requestjob() + assert self.workerout.qsize() == 0 + self.dispatcher.gather( self.workers) + +# endclass Dispatcher +def DaskWorker2( job, state, id2word, num_topics, chunksize, alpha, eta, distributed, random_state=None): + model = LdaModel(id2word=id2word, num_topics=num_topics, chunksize=chunksize, + alpha=alpha, eta=eta, distributed=distributed, random_state=random_state) + + if state: + model.state = state + model.sync_state() + model.state.reset() + + #self.finished = False + model.do_estep(job) + result = model.state + assert isinstance(result, LdaState) + model.clear() # free up mem in-between two EM cycles + return result + +class DaskDispatcher2(object): + """ + Dispatcher object that communicates and coordinates individual workers. + + There should never be more than one dispatcher running at any one time. + """ + + # def initialize(self, **model_params): + def initialize(self, id2word, num_topics, chunksize, alpha, eta, distributed, random_state=None): + """ + `model_params` are parameters used to initialize individual workers (gets + handed all the way down to `worker.initialize()`). 
+ """ + self.id2word = id2word + self.num_topics = num_topics + self.chunksize = chunksize + self.alpha = alpha + self.eta = eta + self.distributed = distributed + self.random_state = random_state + + + # locate all available workers and store their proxies, for subsequent RMI calls + from distributed.security import Security + from distributed import Client + sec = Security() + + self.dispatcher = Client(name="TME", security=sec) + self.id2word = self.dispatcher.scatter(id2word,broadcast=True) + + + #info = yield self.dispatcher.scheduler.identity() + info = self.dispatcher.scheduler_info() + self.workers = info['workers'] + self.dask = [] + self.state = None + if not self.workers or len(self.workers) == 0: + raise RuntimeError( + 'no workers found; run some lda_worker scripts on your machines first!') + + def getworkers(self): + """ + Return pyro URIs of all registered workers. + """ + count = sum( max(1,v['ncores'] //2) for v in self.workers.values()) + return list(range( count ) ) + + + def putjob(self, job): + task = self.dispatcher.submit( + DaskWorker2, job=job, state=self.state, id2word=self.id2word, num_topics=self.num_topics, + chunksize=self.chunksize, alpha=self.alpha, eta=self.eta, distributed=self.distributed, random_state=self.random_state, + pure=False) + assert task.status != 'error' + self.dask += [task] + logger.info("added a new job (len(queue)=%i items)", len(self.dask) ) + + def getstate(self): + """ + Merge states from across all workers and return the result. + """ + assert len(self.dask) > 0 + states = self.dispatcher.gather(self.dask) + self.dask = [] + result = states[0] + for state in states[1:]: + result.merge(state) + return result + def reset(self,state): + self.state = state + self.state = self.dispatcher.scatter(state,broadcast=True) + """ + def __del__(self): + if self.dispatcher: + self.dispatcher.gather( self.dask) + del self.dispatcher + if self.dask: + del self.dask + """ +# endclass Dispatcher + class LdaModel(interfaces.TransformationABC, basemodel.BaseTopicModel): """ @@ -337,24 +702,36 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, if self.optimize_alpha: raise NotImplementedError("auto-optimizing alpha not implemented in distributed LDA") # set up distributed version - try: - import Pyro4 - if ns_conf is None: - ns_conf = {} - - with utils.getNS(**ns_conf) as ns: - from gensim.models.lda_dispatcher import LDA_DISPATCHER_PREFIX - self.dispatcher = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX]) - logger.debug("looking for dispatcher at %s" % str(self.dispatcher._pyroUri)) - self.dispatcher.initialize( - id2word=self.id2word, num_topics=self.num_topics, chunksize=chunksize, - alpha=alpha, eta=eta, distributed=False - ) - self.numworkers = len(self.dispatcher.getworkers()) - logger.info("using distributed version with %i workers", self.numworkers) - except Exception as err: - logger.error("failed to initialize distributed LDA (%s)", err) - raise RuntimeError("failed to initialize distributed LDA (%s)" % err) + # DD + if USE_DASK: + self.dispatcher = DaskDispatcher2() + self.dispatcher.initialize( + id2word=self.id2word, num_topics=self.num_topics, chunksize=chunksize, + alpha=alpha, eta=eta, distributed=False + ) + self.numworkers = len(self.dispatcher.getworkers()) + logger.info( + "using distributed version with %i workers", self.numworkers) + #self.numworkers = sum(self.dispatcher.ncores().values()) + else: + try: + import Pyro4 + if ns_conf is None: + ns_conf = {} + + with utils.getNS(**ns_conf) as ns: 
+                        from gensim.models.lda_dispatcher import LDA_DISPATCHER_PREFIX
+                        self.dispatcher = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX])
+                        logger.debug("looking for dispatcher at %s", str(self.dispatcher._pyroUri))
+                        self.dispatcher.initialize(
+                            id2word=self.id2word, num_topics=self.num_topics, chunksize=chunksize,
+                            alpha=alpha, eta=eta, distributed=False
+                        )
+                        self.numworkers = len(self.dispatcher.getworkers())
+                        logger.info("using distributed version with %i workers", self.numworkers)
+                except Exception as err:
+                    logger.error("failed to initialize distributed LDA (%s)", err)
+                    raise RuntimeError("failed to initialize distributed LDA (%s)" % err)
 
         # Initialize the variational distribution q(beta|lambda)
         self.state = LdaState(self.eta, (self.num_topics, self.num_terms), dtype=self.dtype)
@@ -698,8 +1075,7 @@ def rho():
             dirty = False
 
         reallen = 0
-        chunks = utils.grouper( corpus, chunksize, as_numpy=chunks_as_numpy, dtype=self.dtype)
-        for chunk_no, chunk in enumerate(chunks):
+        for chunk_no, chunk in enumerate(utils.grouper(corpus, chunksize, as_numpy=chunks_as_numpy)):
             reallen += len(chunk)  # keep track of how many documents we've processed so far
 
             if eval_every and ((reallen == lencorpus) or ((chunk_no + 1) % (eval_every * self.numworkers) == 0)):
diff --git a/gensim/utils.py b/gensim/utils.py
index 6d2823c652..f6e5c4fdf3 100644
--- a/gensim/utils.py
+++ b/gensim/utils.py
@@ -1119,7 +1119,7 @@ def substitute_entity(match):
     return RE_HTML_ENTITY.sub(substitute_entity, text)
 
 
-def chunkize_serial(iterable, chunksize, as_numpy=False, dtype=np.float32):
+def chunkize_serial(iterable, chunksize, as_numpy=False):
     """Give elements from the iterable in `chunksize`-ed lists. The last returned
     element may be smaller (if length of collection is not divisible by `chunksize`).
 
@@ -1148,7 +1148,7 @@ def chunkize_serial(iterable, chunksize, as_numpy=False, dtype=np.float32):
     if as_numpy:
         # convert each document to a 2d numpy array (~6x faster when transmitting
         # chunk data over the wire, in Pyro)
-        wrapped_chunk = [[np.array(doc, dtype=dtype) for doc in itertools.islice(it, int(chunksize))]]
+        wrapped_chunk = [[np.array(doc) for doc in itertools.islice(it, int(chunksize))]]
     else:
         wrapped_chunk = [list(itertools.islice(it, int(chunksize)))]
     if not wrapped_chunk[0]:
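
A minimal usage sketch, assuming a dask scheduler and workers are already up
(e.g. `dask-scheduler` plus a few `dask-worker <scheduler-address>` processes,
with the scheduler address in the dask config); with nothing configured,
`Client()` falls back to an in-process LocalCluster. The toy corpus below is
purely illustrative.

    from gensim.corpora import Dictionary
    from gensim.models.ldamodel import LdaModel

    # tiny illustrative corpus
    texts = [
        ['human', 'interface', 'computer'],
        ['survey', 'user', 'computer', 'system', 'response', 'time'],
        ['graph', 'minors', 'trees'],
        ['graph', 'minors', 'survey'],
    ]
    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    # with USE_DASK = True, distributed=True routes through DaskDispatcher2:
    # reset() scatters the current LdaState to the workers, putjob() submits
    # one DaskWorker2 E-step task per chunk, and getstate() gathers and merges
    lda = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2,
                   chunksize=2, distributed=True)
    print(lda.print_topics())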