Refactor
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
// A single rewrite rule: `first` holds the regular expression and
// `second` the replacement template applied on a match.
39 typedef std::pair<std::string, std::string> string_pair;
// Ordered list of rewrite rules.
40 typedef std::vector<string_pair> spair_vec;
// Mutable iterator over a list of rewrite rules.
41 typedef spair_vec::iterator spv_iter;
42
43 class FilterHeaderRewrite: public mp::filter::Base {
44 public:
45     void process(mp::Package & package) const 
46     {
47         Z_GDU *gdu = package.request().get();
48         //map of request/response vars
49         std::map<std::string, std::string> vars;
50         //we have an http req
51         if (gdu && gdu->which == Z_GDU_HTTP_Request)
52         {
53             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
54             mp::odr o;
55             std::cout << ">> Request headers" << std::endl;
56             rewrite_reqline(o, hreq, vars);
57             rewrite_headers(o, hreq->headers, vars);
58             package.request() = gdu;
59         }
60         package.move();
61         gdu = package.response().get();
62         if (gdu && gdu->which == Z_GDU_HTTP_Response)
63         {
64             Z_HTTP_Response *hr = gdu->u.HTTP_Response;
65             std::cout << "Response " << hr->code;
66             std::cout << "<< Respose headers" << std::endl;
67             mp::odr o;
68             rewrite_headers(o, hr->headers, vars);
69             package.response() = gdu;
70         }
71     }
72
73     void rewrite_reqline (mp::odr & o, Z_HTTP_Request *hreq,
74             std::map<std::string, std::string> & vars) const 
75     {
76         //rewrite the request line
77         std::string path;
78         if (strstr(hreq->path, "http://") == hreq->path)
79         {
80             std::cout << "Path in the method line is absolute, " 
81                 "possibly a proxy request\n";
82             path += hreq->path;
83         }
84         else
85         {
86             //TODO what about proto
87             path += z_HTTP_header_lookup(hreq->headers, "Host");
88             path += hreq->path; 
89         }
90         std::cout << "Proxy request URL is " << path << std::endl;
91         std::string npath = 
92             test_patterns(vars, path, req_uri_pats, req_groups_bynum);
93         std::cout << "Resp request URL is " << npath << std::endl;
94         if (!npath.empty())
95             hreq->path = odr_strdup(o, npath.c_str());
96     }
97  
98     void rewrite_headers (mp::odr & o, Z_HTTP_Header *headers,
99             std::map<std::string, std::string> & vars) const 
100     {
101         for (Z_HTTP_Header *header = headers;
102                 header != 0; 
103                 header = header->next) 
104         {
105             std::string sheader(header->name);
106             sheader += ": ";
107             sheader += header->value;
108             std::cout << header->name << ": " << header->value << std::endl;
109             std::string out = test_patterns(vars, 
110                     sheader, 
111                     req_uri_pats, req_groups_bynum);
112             if (!out.empty()) 
113             {
114                 size_t pos = out.find(": ");
115                 if (pos == std::string::npos)
116                 {
117                     std::cout << "Header malformed during rewrite, ignoring";
118                     continue;
119                 }
120                 header->name = odr_strdup(o, out.substr(0, pos).c_str());
121                 header->value = odr_strdup(o, out.substr(pos+2, 
122                             std::string::npos).c_str());
123             }
124         }
125     }
126
127     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
128
129     /**
130      * Tests pattern from the vector in order and executes recipe on
131        the first match.
132      */
133     const std::string test_patterns(
134             std::map<std::string, std::string> & vars,
135             const std::string & txt, 
136             const spair_vec & uri_pats,
137             const std::vector<std::map<int, std::string> > & groups_bynum_vec)
138         const
139     {
140         for (int i = 0; i < uri_pats.size(); i++) 
141         {
142             std::string out = search_replace(vars, txt, 
143                     uri_pats[i].first, uri_pats[i].second,
144                     groups_bynum_vec[i]);
145             if (!out.empty()) return out;
146         }
147         return "";
148     }
149
150
151     const std::string search_replace(
152             std::map<std::string, std::string> & vars,
153             const std::string & txt,
154             const std::string & uri_re,
155             const std::string & uri_pat,
156             const std::map<int, std::string> & groups_bynum) const
157     {
158         //exec regex against value
159         boost::regex re(uri_re);
160         boost::smatch what;
161         std::string::const_iterator start, end;
162         start = txt.begin();
163         end = txt.end();
164         std::string out;
165         while (regex_search(start, end, what, re)) //find next full match
166         {
167             unsigned i;
168             for (i = 1; i < what.size(); ++i)
169             {
170                 //check if the group is named
171                 std::map<int, std::string>::const_iterator it
172                     = groups_bynum.find(i);
173                 if (it != groups_bynum.end()) 
174                 {   //it is
175                     std::string name = it->second;
176                     if (!what[i].str().empty())
177                         vars[name] = what[i];
178                 }
179
180             }
181             //prepare replacement string
182             std::string rvalue = sub_vars(uri_pat, vars);
183             //rewrite value
184             std::string rhvalue = what.prefix().str() 
185                 + rvalue + what.suffix().str();
186             std::cout << "! Rewritten '"+what.str(0)+"' to '"+rvalue+"'\n";
187             out += rhvalue;
188             start = what[0].second; //move search forward
189         }
190         return out;
191     }
192
193     static void parse_groups(
194             const spair_vec & uri_pats,
195             std::vector<std::map<int, std::string> > & groups_bynum_vec)
196     {
197         for (int h = 0; h < uri_pats.size(); h++) 
198         {
199             int gnum = 0;
200             bool esc = false;
201             //regex is first, subpat is second
202             std::string str = uri_pats[h].first;
203             //for each pair we have an indexing map
204             std::map<int, std::string> groups_bynum;
205             for (int i = 0; i < str.size(); ++i)
206             {
207                 if (!esc && str[i] == '\\')
208                 {
209                     esc = true;
210                     continue;
211                 }
212                 if (!esc && str[i] == '(') //group starts
213                 {
214                     gnum++;
215                     if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
216                     {
217                         i++;
218                         if (i+1 < str.size() && str[i+1] == ':') //non-capturing
219                         {
220                             if (gnum > 0) gnum--;
221                             i++;
222                             continue;
223                         }
224                         if (i+1 < str.size() && str[i+1] == 'P') //optional, python
225                             i++;
226                         if (i+1 < str.size() && str[i+1] == '<') //named
227                         {
228                             i++;
229                             std::string gname;
230                             bool term = false;
231                             while (++i < str.size())
232                             {
233                                 if (str[i] == '>') { term = true; break; }
234                                 if (!isalnum(str[i])) 
235                                     throw mp::filter::FilterException
236                                         ("Only alphanumeric chars allowed, found "
237                                          " in '" 
238                                          + str 
239                                          + "' at " 
240                                          + boost::lexical_cast<std::string>(i)); 
241                                 gname += str[i];
242                             }
243                             if (!term)
244                                 throw mp::filter::FilterException
245                                     ("Unterminated group name '" + gname 
246                                      + " in '" + str +"'");
247                             groups_bynum[gnum] = gname;
248                             std::cout << "Found named group '" << gname 
249                                 << "' at $" << gnum << std::endl;
250                         }
251                     }
252                 }
253                 esc = false;
254             }
255             groups_bynum_vec.push_back(groups_bynum);
256         }
257     }
258
259     static std::string sub_vars (const std::string & in, 
260             const std::map<std::string, std::string> & vars)
261     {
262         std::string out;
263         bool esc = false;
264         for (int i = 0; i < in.size(); ++i)
265         {
266             if (!esc && in[i] == '\\')
267             {
268                 esc = true;
269                 continue;
270             }
271             if (!esc && in[i] == '$') //var
272             {
273                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
274                 {
275                     ++i;
276                     std::string name;
277                     bool term = false;
278                     while (++i < in.size()) 
279                     {
280                         if (in[i] == '}') { term = true; break; }
281                         name += in[i];
282                     }
283                     if (!term) throw mp::filter::FilterException
284                         ("Unterminated var ref in '"+in+"' at "
285                          + boost::lexical_cast<std::string>(i));
286                     std::map<std::string, std::string>::const_iterator it
287                         = vars.find(name);
288                     if (it != vars.end())
289                     {
290                         out += it->second;
291                     }
292                 }
293                 else
294                 {
295                     throw mp::filter::FilterException
296                         ("Malformed or trimmed var ref in '"
297                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
298                 }
299                 continue;
300             }
301             //passthru
302             out += in[i];
303             esc = false;
304         }
305         return out;
306     }
307     
308     void configure(
309             const spair_vec req_uri_pats,
310             const spair_vec res_uri_pats)
311     {
312        //TODO should we really copy them out?
313        this->req_uri_pats = req_uri_pats;
314        this->res_uri_pats = res_uri_pats;
315        //pick up names
316        parse_groups(req_uri_pats, req_groups_bynum);
317        parse_groups(res_uri_pats, res_groups_bynum);
318     };
319
320 private:
321     spair_vec req_uri_pats;
322     spair_vec res_uri_pats;
323     std::vector<std::map<int, std::string> > req_groups_bynum;
324     std::vector<std::map<int, std::string> > res_groups_bynum;
325
326 };
327
328
329 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
330 {
331     try
332     {
333        FilterHeaderRewrite fhr;
334     }
335     catch ( ... ) {
336         BOOST_CHECK (false);
337     }
338 }
339
340 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
341 {
342     try
343     {
344         mp::RouterChain router;
345
346         FilterHeaderRewrite fhr;
347         
348         spair_vec vec_req;
349         vec_req.push_back(std::make_pair(
350         "(?<proto>http\\:\\/\\/s?)(?<pxhost>[^\\/?#]+)\\/(?<pxpath>[^\\/]+)"
351         "\\/(?<host>[^\\/]+)(?<path>.*)",
352         "${proto}${host}${path}"
353         ));
354         vec_req.push_back(std::make_pair(
355         "(?:Host\\: )(.*)",
356         "Host: localhost"
357         ));
358
359         spair_vec vec_res;
360         vec_res.push_back(std::make_pair(
361         "(?<proto>http\\:\\/\\/s?)(?<host>[^\\/?#]+)\\/(?<path>[^ >]+)",
362         "http://${pxhost}/${pxpath}/${host}/${path}"
363         ));
364         
365         fhr.configure(vec_req, vec_res);
366
367         mp::filter::HTTPClient hc;
368         
369         router.append(fhr);
370         router.append(hc);
371
372         // create an http request
373         mp::Package pack;
374
375         mp::odr odr;
376         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
377         "http://proxyhost/proxypath/localhost:80/~jakub/targetsite.php", 0, 1);
378
379         pack.request() = gdu_req;
380
381         //feed to the router
382         pack.router(router).move();
383
384         //analyze the response
385         Z_GDU *gdu_res = pack.response().get();
386         BOOST_CHECK(gdu_res);
387         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
388         
389         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
390         BOOST_CHECK(hres);
391
392     }
393     catch (std::exception & e) {
394         std::cout << e.what();
395         BOOST_CHECK (false);
396     }
397 }
398
399 /*
400  * Local variables:
401  * c-basic-offset: 4
402  * c-file-style: "Stroustrup"
403  * indent-tabs-mode: nil
404  * End:
405  * vim: shiftwidth=4 tabstop=8 expandtab
406  */
407