Use C++ strings and handle non-capturing groups
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
39 class FilterHeaderRewrite: public mp::filter::Base {
40 public:
41     void process(mp::Package & package) const {
42         Z_GDU *gdu = package.request().get();
43         //map of request/response vars
44         std::map<std::string, std::string> vars;
45         //we have an http req
46         if (gdu && gdu->which == Z_GDU_HTTP_Request)
47         {
48             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
49             mp::odr o;
50             //rewrite the request line
51             std::string path;
52             if (strstr(hreq->path, "http://") == hreq->path)
53             {
54                 std::cout << "Path in the method line is absolute, " 
55                     "possibly a proxy request\n";
56                 path += hreq->path;
57             }
58             else
59             {
60                path += z_HTTP_header_lookup(hreq->headers, "Host");
61                path += hreq->path; 
62             }
63             std::cout << "Proxy request URL is " << path << std::endl;
64             std::string npath = 
65                 search_replace(vars, path, req_uri_rx, req_uri_pat);
66             std::cout << "Resp request URL is " << npath << std::endl;
67             if (!npath.empty())
68                 hreq->path = odr_strdup(o, npath.c_str());
69             std::cout << ">> Request headers" << std::endl;
70             //iterate headers
71             for (Z_HTTP_Header *header = hreq->headers;
72                     header != 0; 
73                     header = header->next) 
74             {
75                 std::cout << header->name << ": " << header->value << std::endl;
76                 std::string out = search_replace(vars, 
77                         std::string(header->value), 
78                         req_uri_rx, req_uri_pat);
79                 if (!out.empty())
80                     header->value = odr_strdup(o, out.c_str());
81             }
82             package.request() = gdu;
83         }
84         package.move();
85         gdu = package.response().get();
86         if (gdu && gdu->which == Z_GDU_HTTP_Response)
87         {
88             std::cout << "<< Respose headers" << std::endl;
89             Z_HTTP_Response *hr = gdu->u.HTTP_Response;
90             mp::odr o;
91             //iterate headers
92             for (Z_HTTP_Header *header = hr->headers;
93                     header != 0; 
94                     header = header->next) 
95             {
96                 std::cout << header->name << ": " << header->value << std::endl;
97                 std::string out = search_replace(vars, 
98                         std::string(header->value), 
99                         resp_uri_rx, resp_uri_pat);
100                 if (!out.empty())
101                     header->value = odr_strdup(o, out.c_str());
102             }
103             package.response() = gdu;
104         }
105     };
106
107     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
108
109     const std::string search_replace(
110             std::map<std::string, std::string> & vars,
111             const std::string txt,
112             const std::string & uri_re,
113             const std::string & uri_pat) const
114     {
115         //exec regex against value
116         boost::regex re(uri_re);
117         boost::smatch what;
118         std::string::const_iterator start, end;
119         start = txt.begin();
120         end = txt.end();
121         std::string out;
122         while (regex_search(start, end, what, re)) //find next full match
123         {
124             unsigned i;
125             for (i = 1; i < what.size(); ++i)
126             {
127                 //check if the group is named
128                 std::map<int, std::string>::const_iterator it
129                     = groups_by_num.find(i);
130                 if (it != groups_by_num.end()) 
131                 {   //it is
132                     std::string name = it->second;
133                     vars[name] = what[i];
134                 }
135
136             }
137             //prepare replacement string
138             std::string rvalue = sub_vars(uri_pat, vars);
139             //rewrite value
140             std::string rhvalue = what.prefix().str() 
141                 + rvalue + what.suffix().str();
142             std::cout << "! Rewritten '"+what.str(0)+"' to '"+rvalue+"'\n";
143             out += rhvalue;
144             start = what[0].second; //move search forward
145         }
146         return out;
147     };
148
149     static void parse_groups(const std::string & str,
150             std::map<int, std::string> & groups_bynum,
151             std::map<std::string, int> & groups_byname)
152     {
153        int gnum = 0;
154        bool esc = false;
155        for (int i = 0; i < str.size(); ++i)
156        {
157            if (!esc && str[i] == '\\')
158            {
159                esc = true;
160                continue;
161            }
162            if (!esc && str[i] == '(') //group starts
163            {
164                gnum++;
165                if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
166                {
167                    i++;
168                    if (i+1 < str.size() && str[i+1] == ':') //non-capturing
169                    {
170                        if (gnum > 0) gnum--;
171                        i++;
172                        continue;
173                    }
174                    if (i+1 < str.size() && str[i+1] == 'P') //optional, python
175                        i++;
176                    if (i+1 < str.size() && str[i+1] == '<') //named
177                    {
178                        i++;
179                        std::string gname;
180                        bool term = false;
181                        while (++i < str.size())
182                        {
183                            if (str[i] == '>') { term = true; break; }
184                            if (!isalnum(str[i])) 
185                                throw mp::filter::FilterException
186                                    ("Only alphanumeric chars allowed, found "
187                                     " in '" 
188                                     + str 
189                                     + "' at " 
190                                     + boost::lexical_cast<std::string>(i)); 
191                            gname += str[i];
192                        }
193                        if (!term)
194                            throw mp::filter::FilterException
195                                ("Unterminated group name '" + gname 
196                                 + " in '" + str +"'");
197                       groups_bynum[gnum] = gname;
198                       groups_byname[gname] = gnum;
199                       std::cout << "Found named group '" << gname 
200                           << "' at $" << gnum << std::endl;
201                    }
202                }
203            }
204            esc = false;
205        }
206     }
207
208     static std::string sub_vars (const std::string & in, 
209             const std::map<std::string, std::string> & vars)
210     {
211         std::string out;
212         bool esc = false;
213         for (int i = 0; i < in.size(); ++i)
214         {
215             if (!esc && in[i] == '\\')
216             {
217                 esc = true;
218                 continue;
219             }
220             if (!esc && in[i] == '$') //var
221             {
222                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
223                 {
224                     ++i;
225                     std::string name;
226                     bool term = false;
227                     while (++i < in.size()) 
228                     {
229                         if (in[i] == '}') { term = true; break; }
230                         name += in[i];
231                     }
232                     if (!term) throw mp::filter::FilterException
233                         ("Unterminated var ref in '"+in+"' at "
234                          + boost::lexical_cast<std::string>(i));
235                     std::map<std::string, std::string>::const_iterator it
236                         = vars.find(name);
237                     if (it != vars.end())
238                         out += it->second;
239                 }
240                 else
241                 {
242                     throw mp::filter::FilterException
243                         ("Malformed or trimmed var ref in '"
244                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
245                 }
246                 continue;
247             }
248             //passthru
249             out += in[i];
250             esc = false;
251         }
252         return out;
253     }
254     
255     void configure(
256             const std::string & req_uri_rx, 
257             const std::string & req_uri_pat, 
258             const std::string & resp_uri_rx, 
259             const std::string & resp_uri_pat) 
260     {
261        this->req_uri_rx = req_uri_rx;
262        this->req_uri_pat = req_uri_pat;
263        //pick up names
264        parse_groups(req_uri_rx, groups_by_num, groups_by_name);
265        this->resp_uri_rx = resp_uri_rx;
266        this->resp_uri_pat = resp_uri_pat;
267     };
268
269 private:
270     std::map<std::string, std::string> vars;
271     std::string req_uri_rx;
272     std::string resp_uri_rx;
273     std::string req_uri_pat;
274     std::string resp_uri_pat;
275     std::map<int, std::string> groups_by_num;
276     std::map<std::string, int> groups_by_name;
277
278 };
279
280
281 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
282 {
283     try
284     {
285        FilterHeaderRewrite fhr;
286     }
287     catch ( ... ) {
288         BOOST_CHECK (false);
289     }
290 }
291
292 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
293 {
294     try
295     {
296         mp::RouterChain router;
297
298         FilterHeaderRewrite fhr;
299         fhr.configure(
300                 "(?:http\\:\\/\\/s?)?(?<host>[A-Za-z.]+):(?<port>\\d+)",
301                 "http://${host}:${port}/somepath",
302                 //rewrite connection close
303                 "close",
304                 "open for ${host}");
305
306         mp::filter::HTTPClient hc;
307         
308         router.append(fhr);
309         router.append(hc);
310
311         // create an http request
312         mp::Package pack;
313
314         mp::odr odr;
315         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
316             "http://localhost:80/~jakub/targetsite.php", 0, 1);
317
318         pack.request() = gdu_req;
319
320         //feed to the router
321         pack.router(router).move();
322
323         //analyze the response
324         Z_GDU *gdu_res = pack.response().get();
325         BOOST_CHECK(gdu_res);
326         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
327         
328         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
329         BOOST_CHECK(hres);
330
331     }
332     catch (std::exception & e) {
333         std::cout << e.what();
334         BOOST_CHECK (false);
335     }
336 }
337
338 /*
339  * Local variables:
340  * c-basic-offset: 4
341  * c-file-style: "Stroustrup"
342  * indent-tabs-mode: nil
343  * End:
344  * vim: shiftwidth=4 tabstop=8 expandtab
345  */
346