Functional request header rewriting
authorJakub Skoczen <jakub@indexdata.dk>
Wed, 1 May 2013 15:15:34 +0000 (17:15 +0200)
committerJakub Skoczen <jakub@indexdata.dk>
Wed, 1 May 2013 15:15:34 +0000 (17:15 +0200)
The regex gets scanned for named captures, which are then indexed
and used in the request URL recipe (configurable).

src/test_filter_rewrite.cpp

index be6a071..40ab6ba 100644 (file)
@@ -27,7 +27,8 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 
 #define BOOST_REGEX_MATCH_EXTRA
 
-#include <boost/regex.hpp>
+#include <boost/xpressive/xpressive.hpp>
+#include <boost/lexical_cast.hpp>
 
 #define BOOST_AUTO_TEST_MAIN
 #define BOOST_TEST_DYN_LINK
@@ -35,6 +36,7 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 #include <boost/test/auto_unit_test.hpp>
 
 using namespace boost::unit_test;
+using namespace boost::xpressive;
 namespace mp = metaproxy_1;
 
 class FilterHeaderRewrite: public mp::filter::Base {
@@ -76,50 +78,157 @@ public:
     void rewrite_req_header(Z_HTTP_Header *header) const
     {
         //exec regex against value
-        boost::regex e(req_uri_pat, boost::regex::perl);
-        boost::smatch what;
+        sregex re = sregex::compile(req_uri_rx);
+        smatch what;
         std::string hvalue(header->value);
-        if(boost::regex_match(hvalue, what, e, boost::match_extra))
+        std::map<std::string, std::string> vars;
+        if (regex_match(hvalue, what, re))
         {
-            unsigned i, j;
-            std::cout << "** Match found **\n   Sub-Expressions:\n";
-            for(i = 0; i < what.size(); ++i)
-                std::cout << "      $" << i << " = \"" << what[i] << "\"\n";
-            std::cout << "   Captures:\n";
-            for(i = 0; i < what.size(); ++i)
+            unsigned i;
+            for (i = 1; i < what.size(); ++i)
             {
-                std::cout << "      $" << i << " = {";
-                for(j = 0; j < what.captures(i).size(); ++j)
-                {
-                    if(j)
-                        std::cout << ", ";
-                    else
-                        std::cout << " ";
-                    std::cout << "\"" << what.captures(i)[j] << "\"";
+                //check if the group is named
+                std::map<int, std::string>::const_iterator it
+                    = groups_by_num.find(i);
+                if (it != groups_by_num.end()) 
+                {   //it is
+                    std::string name = it->second;
+                    vars[name] = what[i];
                 }
-                std::cout << " }\n";
             }
+            //rewrite the header according to the recipe
+            std::string rvalue = sub_vars(req_uri_pat, vars);
+            std::cout << "Rewritten '"+hvalue+"' to '"+rvalue+"'\n";
         }
         else
         {
-            std::cout << "** No Match found **\n";
+            std::cout << "No match found in '" + hvalue + "'\n";
         }
-        //iteratate over named groups
-        //set the captured values in the map
-        //rewrite the header according to the hardcoded recipe
     };
+
+    // Scans the regex source 'str' for named capture groups, written as
+    // (?<name>...) or (?P<name>...), and records each one under its capture
+    // number in 'groups_bynum' and under its name in 'groups_byname'.
+    // Only constructs that create a capture are counted: plain '(' groups
+    // and named groups. Attribute groups such as (?:...), (?=...), (?!...),
+    // (?<=...) and (?<!...) do not consume a capture number, which keeps the
+    // recorded numbers in step with the sub-match indexes of the regex.
+    // A backslash escapes the next character, so "\(" never opens a group.
+    // Throws mp::filter::FilterException on a non-alphanumeric character in
+    // a group name or on a name that is not closed with '>'.
+    // NOTE(review): a '(' inside a character class, e.g. "[(]", is still
+    // counted as a group start.
+    static void parse_groups(const std::string & str,
+            std::map<int, std::string> & groups_bynum,
+            std::map<std::string, int> & groups_byname)
+    {
+       const std::string::size_type len = str.size();
+       int gnum = 0;
+       bool esc = false;
+       for (std::string::size_type i = 0; i < len; ++i)
+       {
+           if (!esc && str[i] == '\\')
+           {
+               esc = true;
+               continue;
+           }
+           if (!esc && str[i] == '(') //group starts
+           {
+               if (i+1 >= len || str[i+1] != '?')
+               {
+                   gnum++; //plain capturing group
+               }
+               else //group with attrs
+               {
+                   i++;
+                   if (i+1 < len && str[i+1] == 'P') //optional, python
+                       i++;
+                   //"<=" and "<!" open a lookbehind, not a group name
+                   const bool lookbehind = i+2 < len
+                       && (str[i+2] == '=' || str[i+2] == '!');
+                   if (i+1 < len && str[i+1] == '<' && !lookbehind) //named
+                   {
+                       i++;
+                       std::string gname;
+                       bool term = false;
+                       while (++i < len)
+                       {
+                           if (str[i] == '>') { term = true; break; }
+                           //cast: isalnum is undefined for negative chars
+                           if (!isalnum(static_cast<unsigned char>(str[i])))
+                               throw mp::filter::FilterException
+                                   ("Only alphanumeric chars allowed, found '"
+                                    + std::string(1, str[i])
+                                    + "' in '"
+                                    + str
+                                    + "' at "
+                                    + boost::lexical_cast<std::string>(i));
+                           gname += str[i];
+                       }
+                       if (!term)
+                           throw mp::filter::FilterException
+                               ("Unterminated group name '" + gname
+                                + "' in '" + str +"'");
+                       gnum++; //a named group is a capturing group too
+                       groups_bynum[gnum] = gname;
+                       groups_byname[gname] = gnum;
+                       std::cout << "Found named group '" << gname
+                           << "' at $" << gnum << std::endl;
+                   }
+               }
+           }
+           esc = false;
+       }
+    }
+
+    // Expands "${name}" references in the recipe 'in' using 'vars' and
+    // returns the resulting string.
+    //  - a backslash makes the following character literal (the backslash
+    //    itself is not copied to the output);
+    //  - a reference to a name missing from 'vars' expands to nothing;
+    //  - an unescaped '$' not followed by '{', or a reference without a
+    //    closing '}', throws mp::filter::FilterException.
+    static std::string sub_vars (const std::string & in, 
+            const std::map<std::string, std::string> & vars)
+    {
+        typedef std::map<std::string, std::string>::const_iterator var_iter;
+        const std::string::size_type len = in.size();
+        std::string result;
+        bool escaped = false;
+        std::string::size_type pos = 0;
+        while (pos < len)
+        {
+            const char ch = in[pos];
+            //ordinary or escaped character: copy it through
+            if (escaped || (ch != '\\' && ch != '$'))
+            {
+                result += ch;
+                escaped = false;
+                ++pos;
+                continue;
+            }
+            //unescaped backslash: remember it, emit nothing
+            if (ch == '\\')
+            {
+                escaped = true;
+                ++pos;
+                continue;
+            }
+            //unescaped '$': it must introduce a "${...}" reference
+            if (pos + 1 >= len || in[pos + 1] != '{')
+                throw mp::filter::FilterException
+                    ("Malformed or trimmed var ref in '"
+                     +in+"' at "+boost::lexical_cast<std::string>(pos));
+            const std::string::size_type name_start = pos + 2;
+            const std::string::size_type close = in.find('}', name_start);
+            if (close == std::string::npos)
+                throw mp::filter::FilterException
+                    ("Unterminated var ref in '"+in+"' at "
+                     + boost::lexical_cast<std::string>(len));
+            const std::string name
+                = in.substr(name_start, close - name_start);
+            const var_iter found = vars.find(name);
+            if (found != vars.end())
+                result += found->second;
+            //resume right after the closing brace
+            pos = close + 1;
+        }
+        return result;
+    }
     
-    void configure(const std::string & req_uri_pat, 
+    void configure(
+            const std::string & req_uri_rx, 
+            const std::string & req_uri_pat, 
+            const std::string & resp_uri_rx, 
             const std::string & resp_uri_pat) 
     {
+       this->req_uri_rx = req_uri_rx;
        this->req_uri_pat = req_uri_pat;
+       //pick up names
+       parse_groups(req_uri_rx, groups_by_num, groups_by_name);
+       this->resp_uri_rx = resp_uri_rx;
        this->resp_uri_pat = resp_uri_pat;
     };
 
 private:
     std::map<std::string, std::string> vars;
+    std::string req_uri_rx;
+    std::string resp_uri_rx;
     std::string req_uri_pat;
     std::string resp_uri_pat;
+    std::map<int, std::string> groups_by_num;
+    std::map<std::string, int> groups_by_name;
+
 };
 
 
@@ -141,7 +250,11 @@ BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
         mp::RouterChain router;
 
         FilterHeaderRewrite fhr;
-        fhr.configure(".*(localhost).*", ".*(localhost).*");
+        fhr.configure(
+                ".*?(?P<host>[^:]+):(?P<port>\\d+).*",
+                "http://${host}:${port}/somepath",
+                ".*(localhost).*",
+                "http:://g");
         mp::filter::HTTPClient hc;
         
         router.append(fhr);
@@ -168,7 +281,8 @@ BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
         BOOST_CHECK(hres);
 
     }
-    catch ( ... ) {
+    catch (std::exception & e) {
+        std::cout << e.what();
         BOOST_CHECK (false);
     }
 }