Rewrite of the config class to support config overload
authorFrederic Massart <fred@moodle.com>
Wed, 20 Feb 2013 07:47:02 +0000 (15:47 +0800)
committerFrederic Massart <fred@moodle.com>
Wed, 20 Feb 2013 07:47:02 +0000 (15:47 +0800)
lib/config.py

index 8d96cb2..84d223f 100644 (file)
@@ -28,30 +28,23 @@ import re
 from exceptions import ConfigFileCouldNotBeLoaded, ConfigFileNotFound, ConfigFileCouldNotBeSaved
 
 
-class Config(object):
-    """Generic config class"""
-
-    directories = None
-    filename = 'config.json'
+class ConfigObject(object):
+    """Configuration object"""
     data = None
-    configfile = None
 
-    def __init__(self, path=None, filename=None):
-        """Creates the configuration object"""
-        self.directories = []
-        if path != None:
-            self.directories.insert(0, path)
-        if filename != None:
-            self.filename = filename
+    def __init__(self):
+        self.data = {}
 
     def add(self, name, value):
-        """Add a new config to the config file"""
+        """Add a new config but throws an exception if already defined"""
         if self.get(name) != None:
             raise Exception('Setting already declared')
         self.set(name, value)
 
-    def get(self, name=None):
-        """Return a setting or None if not found"""
+    def get(self, name=None, default=None):
+        """Return all the settings, or the setting if name is specified.
+        In case the setting is not found default is returned instead.
+        """
         if name == None:
             return self.data
         name = unicode(name).split('.')
@@ -60,28 +53,55 @@ class Config(object):
             try:
                 data = data[n]
             except:
-                data = None
+                data = default
                 break
         return data
 
-    def load(self, fn=None):
-        """Loads the configuration from the config file"""
-        if fn == None:
-            fn = self.resolve()
-        self.configfile = fn
+    def getFlat(self, data=None, parent=''):
+        """Return the entire data as a flat array"""
+        flatten = {}
+        if data == None:
+            data = self.get()
+        for k, v in data.items():
+            newKey = '%s.%s' % (parent, k) if parent != '' else k
+            if type(v) == dict:
+                for k2, v2 in self.getFlat(v, newKey).items():
+                    flatten[k2] = v2
+            else:
+                flatten[newKey] = v
+        return flatten
+
+    def load(self, data, merge=False):
+        """Load up the data"""
+        if merge:
+            data = self.mergeData(self.data, data)
+        self.data = data
+
+    def loadFromFile(self, filepath, merge=False):
+        """Load the settings from a file"""
+        if not os.path.isfile(filepath):
+            raise ConfigFileNotFound('Could not find the config file %s' % filepath)
         try:
             lines = ''
-            f = open(fn, 'r')
+            f = open(filepath, 'r')
             for l in f:
                 if re.match(r'^\s*//', l):
                     continue
                 lines += l
-            self.data = {}
             if len(lines) > 0:
-                self.data = json.loads(lines)
+                self.load(json.loads(lines), merge=merge)
             f.close()
         except:
-            raise ConfigFileCouldNotBeLoaded('Could not load config file %s' % fn)
+            raise ConfigFileCouldNotBeLoaded('Could not load config file %s' % filepath)
+
+    def mergeData(self, origData, newData):
+        """Recursively merge 2 dict of data"""
+        for k, v in newData.items():
+            if k in origData and type(v) == dict:
+                origData[k] = self.mergeData(origData[k], v)
+            else:
+                origData[k] = v
+        return origData
 
     def remove(self, name):
         """Remove a setting"""
@@ -103,27 +123,15 @@ class Config(object):
                     break
         self.save()
 
-    def resolve(self):
-        """Resolve the path to the configuration file"""
-        path = None
-        for directory in self.directories:
-            candidate = os.path.expanduser(os.path.join(directory, self.filename))
-            if os.path.isfile(candidate):
-                path = candidate
-                break
-        if path == None:
-            raise ConfigFileNotFound('Could not find config file')
-        return path
-
-    def save(self):
+    def save(self, filepath):
         """Save the settings to the config file"""
         try:
-            f = open(self.configfile, 'w')
-            json.dump(self.data, f, indent=4)
+            f = open(filepath, 'w')
+            json.dump(self.get(), f, indent=4)
             f.close()
         except Exception as e:
             print e
-            raise ConfigFileCouldNotBeSaved('Could not save to config file %s' % self.configfile)
+            raise ConfigFileCouldNotBeSaved('Could not save to config file %s' % filepath)
 
     def set(self, name, value):
         """Set a new setting"""
@@ -142,15 +150,136 @@ class Config(object):
                 except:
                     data[n] = {}
                     data = data[n]
-        self.save()
+
+
+class Config(object):
+    """Generic config class"""
+
+    files = None
+    _loaded = False
+
+    # ConfigObject storing a merge of all the config files
+    data = None
+
+    # ConfigObject for each config file
+    objects = None
+
+    def __init__(self, files=[]):
+        """Creates the configuration object"""
+        self.files = []
+        for f in files:
+            self.files.append(f)
+        self.data = ConfigObject()
+        self.objects = {}
+
+    def add(self, name, value):
+        """Add a new config"""
+        self.data.add(name, value)
+
+    def get(self, name=None):
+        """Return a setting"""
+        return self.data.get(name)
+
+    def load(self, allowMissing=False):
+        """Loads the configuration from the config files"""
+
+        if self._loaded:
+            return True
+
+        for fn in self.files:
+            self.objects[fn] = ConfigObject()
+            try:
+                self.objects[fn].loadFromFile(fn)
+            except ConfigFileNotFound as e:
+                if not allowMissing:
+                    raise e
+            self.data.load(self.objects[fn].get(), merge=True)
+
+    def reload(self):
+        """Reload the configuration"""
+        self._loaded = False
+        self.load()
+
+    def remove(self, name):
+        """Remove a setting"""
+        self.data.remove(name)
+
+    def save(self, to, confObj=None):
+        """Save the settings to the config file"""
+        if not confObj:
+            confObj = self.data
+        try:
+            f = open(to, 'w')
+            json.dump(confObj.get(), f, indent=4)
+            f.close()
+        except Exception as e:
+            print e
+            raise ConfigFileCouldNotBeSaved('Could not save to config file %s' % to)
+
+    def set(self, name, value):
+        """Set a new setting"""
+        self.data.set(name, value)
 
 
 class Conf(Config):
     """MDK config class"""
 
-    def __init__(self, path=None, filename=None):
-        Config.__init__(self, path, filename)
-        self.directories.append('~/.moodle-sdk/')
-        self.directories.append('/etc/moodle-sdk/')
-        self.directories.append(os.path.join(os.path.dirname(__file__), '..'))
-        self.load()
+    userFile = None
+
+    def __init__(self):
+        self.userFile = os.path.expanduser('~/.moodle-sdk/config.json')
+        files = [
+            os.path.join(os.path.dirname(__file__), '..', 'config-dist.json'),
+            os.path.join(os.path.dirname(__file__), '..', 'config.json'),
+            '/etc/moodle-sdk/config.json',
+            self.userFile,
+        ]
+        Config.__init__(self, files)
+        self.load(allowMissing=True)
+
+    def save(self, to=None, confObj=None):
+        """Save only the difference to the user config file"""
+
+        # The base file to use is the user file
+        to = self.userFile
+        diffData = self.objects[self.userFile]
+
+        files = list(self.files)
+        files.reverse()
+
+        # Each of the know settings will be checked
+        data = self.data.getFlat()
+        for k in sorted(data.keys()):
+            v = data[k]
+            different = False
+            found = False
+
+            # Respect the files order when reading the settings
+            for f in files:
+                o = self.objects[f]
+                ov = o.get(k)
+
+                # The value hasn't been found and is different
+                if not found and ov != None and ov != v:
+                    different = True
+                    break
+                # The value is set
+                elif ov != None and ov == v:
+                    found = True
+
+            # The value differs, or none of the file define it
+            if different or not found:
+                diffData.set(k, v)
+
+        confObj = diffData
+        super(Conf, self).save(to, confObj)
+
+    def remove(self, name):
+        """Remove a setting"""
+        super(Conf, self).remove(name)
+        self.save()
+
+    def set(self, name, value):
+        """Set a new setting"""
+        super(Conf, self).set(name, value)
+        self.save()