#!/usr/bin/python -tt # -*- coding: utf-8 -*- # # Interface for LDAP functions with some extensions # # Copyright (c) 2005 JAS # # Author: Petr Vokac # # $Id: CvutLdap.py,v 1.13 2005/06/26 23:12:29 vokac Exp $ # import sys, exceptions, pickle import ldap import logging import Config __version__ = "$Revision: 1.13 $" class CvutLdap: """ Base class for creating connection (anonymout/authenticated) to the CVUT LDAP data sources. """ class CvutLdapException(exceptions.Exception): """ Base exception class. """ def __init__(self, args = ""): exceptions.Exception.__init__(self, args) class ConnectionException(CvutLdapException): """ Exception during initializing LDAP connection. """ def __init__(self, args = ""): CvutLdap.CvutLdapException.__init__(self, args) class ReadDataException(CvutLdapException): """ Exception during reading serialized result data from stream. """ def __init__(self, args = ""): CvutLdap.CvutLdapException.__init__(self, args) class WriteDataException(CvutLdapException): """ Exception during writing serialized result data to the stream. """ def __init__(self, args = ""): CvutLdap.CvutLdapException.__init__(self, args) def __init__(self, config): """ Initialize instance with "config" described in Config.py """ logging.getLogger().debug("CvutLdap.__init__") self._id = None self._name = None self._servers = None self._base = None self._filter = None self._bind_name = None self._bind_pass = None self._conn = None self._searchRes = None self.setConfig(config) def __del__(self): """ Release connections. """ logging.getLogger().debug("CvutLdap.__del__") self.releaseConnection() def getConfig(self): """ Get actual configuration """ retVal = {} if not self._id == None: retVal["id"] = self._id if not self._name == None: retVal["name"] = self._name if not self._servers == None: retVal["servers"] = self._servers if not self._tls == None: retVal["tls"] = self._tls if not self._base == None: retVal["base"] = self._base if not self._filter == None: retVal["filter"] = self._filter if not self._bind_name == None: retVal["bind_name"] = self._bind_name if not self._bind_pass == None: retVal["bind_pass"] = self._bind_pass return retVal def setConfig(self, config): """ Set whole configuration """ if config.has_key("id"): self.setId(config["id"]) else: self.setId(None) if config.has_key("name"): self.setName(config["name"]) else: self.setName(None) if config.has_key("servers"): self.setServers(config["servers"]) else: self.setServers(None) if config.has_key("tls"): self.setTLS(config["tls"]) else: self.setTLS(None) if config.has_key("base"): self.setBase(config["base"]) else: self.setBase(None) if config.has_key("bind_name"): self.setBindName(config["bind_name"]) else: self.setBindName(None) if config.has_key("bind_pass"): self.setBindPass(config["bind_pass"]) else: self.setBindPass(None) def getId(self): """ Get id """ if self._id == None: return "" else: return self._id def setId(self, id): """ Set id """ self._id = id def getName(self): """ Get name """ if self._name == None: return "" else: return self._name def setName(self, name): """ Set name """ self._name = name def getServers(self): """ Get list of distionaries containing server info """ retVal = [] for server in self._servers: if not server.has_key("protocol"): server["protocol"] = "ldap" retVal.append(server) return retVal def getServersText(self): """ Get string for ldap.initialize operation """ return self._serversText(self.getServers()) def setServers(self, servers): self._servers = self._normServers(servers) def getTLS(self): if self._tls == None: return "false" else: return self._tls def setTLS(self, tls): self._tls = self._normTLS(tls) def _serversText(self, servers): retVal = [] for server in servers: if server.has_key("port"): retVal.append("%s://%s:%s" % (server["protocol"], server["hostname"], server["port"])) else: retVal.append("%s://%s" % (server["protocol"], server["hostname"])) return ",".join(retVal) def _normServers(self, servers): if type(servers) == type(""): servers = servers.split(",") if type(servers) != type([]): return None retVal = [] for server in servers: normServer = {} if type(server) == type({}): # e.g. { "hostname" : "...", "port" : "...", ... } if server.has_key("hostname"): normServer["hostname"] = server["hostname"] if server.has_key("port"): normServer["port"] = server["port"] if server.has_key("protocol"): normServer["protocol"] = server["protocol"] elif type(server) == type(""): # e.g. "ldap://localhost:389" ph = server.split("://", 1) if len(ph) == 2: normServer["protocol"] = ph[0] server = ph[1] if server.find(":") != -1: normServer["hostname"] = server[:server.rindex(":")] normServer["port"] = server[server.rindex(":")+1:] else: normServer["hostname"] = server if len(normServer) > 0: retVal.append(normServer) else: logging.getLogger().warning("CvutLdap._normServers: bad server configuration %s" % server) return retVal def _normTLS(self, tls): if tls == None: return None if tls.lower() == "on" or tls.lower() == "true" or tls.lower() == "1": return "true" else: return "false" def getBase(self): if self._base == None: return "" else: return self._base def setBase(self, base): self._base = base def getFilter(self): if self._filter == None: return "" else: return self._filter def setFilter(self, filter): self._filter = filter def getBindName(self): if self._bindName == None: return "" else: return self._bindName def setBindName(self, bindName): self._bindName = bindName def getBindPass(self): if self._bindPass == None: return "" else: return self._bindPass def setBindPass(self, bindPass): self._bindPass = bindPass def getConnection(self): #, servers = None, tls = None, user = None, passw = None): """ Get connection for specified LDAP server """ if self._conn != None: return self._conn try: logging.getLogger().debug("CvutLdap.getConnection: %s" % self.getServersText()) ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, Config.tls_ca_cert_file) self._conn = ldap.initialize(self.getServersText()) # ldap.protocol_version = ldap.VERSION3 if self.getTLS() == "true": # Now try StartTLS extended operation self._conn.set_option(ldap.OPT_X_TLS,ldap.OPT_X_TLS_DEMAND) self._conn.start_tls_s() self._conn.simple_bind_s(self.getBindName(), self.getBindPass()) return self._conn except ldap.LDAPError, e: logging.getLogger().error("CvutLdap.getConnection: %s" % str(e)) raise CvutLdap.ConnectionException(e) def search(self, base = None, scope = ldap.SCOPE_SUBTREE, searchFilter = None, attrs = None, aditive = False): """ Search in LDAP specified by defined configuration. If you invoke this method without calling clearSearchResult() than all result will be merged Attributes: base -- search base (default: SCOPE_SUBTREE) filter -- search filter (defautl: (objectclass=*)) attrs -- search attributes (default: "all") aditive -- add search result to previous (default: True) """ logging.getLogger().debug("CvutLdap.search(%s, %s, %s, %s, %s)" %(base, scope, searchFilter, attrs, aditive)) if not aditive: self.clear() if self._conn == None: self._conn = self.getConnection() if base == None: base = self._base if searchFilter == None or searchFilter == "": if self._filter == None: searchFilter = "(objectClass=*)" else: searchFilter = self._filter searchRes = self._conn.search_s(base, scope, searchFilter, attrs) if self._searchRes == None: self._searchRes = [] if len(self._searchRes) > 0: for dn, attrs in searchRes: self._searchRes.append((dn, attrs)) else: self._searchRes = searchRes del(searchRes) return self._searchRes def get(self): """ Return raw search result set. """ return self._searchRes def clear(self): """ Clear search result. """ logging.getLogger().debug("CvutLdap.clear()") if self._searchRes != None: del(self._searchRes) self._searchRes = None def read(self, inputStream = sys.stdin): """ Read search result that was stored in the file using standard Python serialization (Pickler). Attributes: inputStream -- opened file stream (for reading) """ logging.getLogger().debug("CvutLdap.read(%s)" % sys.stdout.name) try: self._searchRes = pickle.Unpickler(inputStream).load() except pickle.UnpicklingError, e: logging.getLogger().error("Error reading result data from the stream: %s" % str(e)) raise CvutLdap.ReadDataException(e) # TODO: handle those "unpickling" errors #except AttributeError, e: #except EOFError, e: #except ImportError, e: #except IndexError, e: return self._searchRes def write(self, outputStream = sys.stdout): """ Write search result in the file using standard Python serialization (Pickler). Attributes: outputStream -- opened file stream (for writing) """ logging.getLogger().debug("CvutLdap.write(%s)" % sys.stdout.name) try: # use optimal pickle data format, but "protocol" doesn't exist in earlyer versions if pickle.format_version >= '2.0': pickle.Pickler(outputStream, protocol=-1).dump(self._searchRes) else: pickle.Pickler(outputStream).dump(self._searchRes) except pickle.PicklingError, e: logging.getLogger().error("Error writing result data to the stream: %s" % str(e)) raise CvutLdap.WriteDataException(e) def releaseConnection(self): """ Release resources - search result and LDAP connection. """ self.clear() if self._conn != None: self._conn.unbind() self._conn = None if __name__ == "__main__": streamHandler = logging.StreamHandler() streamHandler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s](%(module)s:%(lineno)d) %(message)s", "%d %b %H:%M:%S")) logging.getLogger().addHandler(streamHandler) logging.getLogger().setLevel(logging.DEBUG) usermap = CvutLdap(Config.ldapUsermap) # print usermap.getConfig() fjfinds = CvutLdap(Config.ldapFjfiNds) # print fjfinds.getConfig() fjfiad = CvutLdap(Config.ldapFjfiAd) # print fjfiad.getConfig() fjfildap = CvutLdap(Config.ldapFjfiLdap) # print fjfildap.getConfig() test = CvutLdap({ "id" : "Testing id", "name" : "Testing name", "servers" : [ "ldaps://fjfi.cvut.cz", { "hostname" : "localhost" } ], "base" : "o=cvut,c=cz", "bind_name" : "Manager", "bind_pass" : "blabla" } ) print test.getName() print test.getServersText() print test.getConfig() usermapConn = usermap.getConnection() fjfindsConn = fjfinds.getConnection() del(usermap) del(fjfinds) fjfildap.search() print "result length: %s" % len(fjfildap.get()) fjfildap.write(open("CvutLdap.p", "wb")) print "result length: %s" % len(fjfildap.get()) fjfildap.clear() print "result length: %s" % fjfildap.get() fjfildap.read(open("CvutLdap.p", "rb")) print "result length: %s" % len(fjfildap.get()) #ldap_extended_operation() #ldap_parse_extended_result() print "Done."