Reconstruct headers before matching
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
//ordered list of (regex, replacement template) pairs
typedef std::vector<std::pair<std::string, std::string> > spair_vec;
//one (regex, replacement template) entry of such a list
typedef spair_vec::value_type string_pair;
//mutable iterator over a pattern list
typedef spair_vec::iterator spv_iter;
42
43 class FilterHeaderRewrite: public mp::filter::Base {
44 public:
45     void process(mp::Package & package) const {
46         Z_GDU *gdu = package.request().get();
47         //map of request/response vars
48         std::map<std::string, std::string> vars;
49         //we have an http req
50         if (gdu && gdu->which == Z_GDU_HTTP_Request)
51         {
52             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
53             mp::odr o;
54             //rewrite the request line
55             std::string path;
56             if (strstr(hreq->path, "http://") == hreq->path)
57             {
58                 std::cout << "Path in the method line is absolute, " 
59                     "possibly a proxy request\n";
60                 path += hreq->path;
61             }
62             else
63             {
64                 //TODO what about proto
65                path += z_HTTP_header_lookup(hreq->headers, "Host");
66                path += hreq->path; 
67             }
68             std::cout << "Proxy request URL is " << path << std::endl;
69             std::string npath = 
70                 test_patterns(vars, path, req_uri_pats, req_groups_bynum);
71             std::cout << "Resp request URL is " << npath << std::endl;
72             if (!npath.empty())
73                 hreq->path = odr_strdup(o, npath.c_str());
74             std::cout << ">> Request headers" << std::endl;
75             //iterate headers
76             for (Z_HTTP_Header *header = hreq->headers;
77                     header != 0; 
78                     header = header->next) 
79             {
80                 std::string sheader(header->name);
81                 sheader += ": ";
82                 sheader += header->value;
83                 std::cout << header->name << ": " << header->value << std::endl;
84                 std::string out = test_patterns(vars, 
85                         sheader, 
86                         req_uri_pats, req_groups_bynum);
87                 if (!out.empty()) 
88                 {
89                     size_t pos = out.find(": ");
90                     if (pos == std::string::npos)
91                     {
92                         std::cout << "Header malformed during rewrite, ignoring";
93                         continue;
94                     }
95                     header->name = odr_strdup(o, out.substr(0, pos).c_str());
96                     header->value = odr_strdup(o, out.substr(pos+2, 
97                                 std::string::npos).c_str());
98                 }
99             }
100             package.request() = gdu;
101         }
102         package.move();
103         gdu = package.response().get();
104         if (gdu && gdu->which == Z_GDU_HTTP_Response)
105         {
106             Z_HTTP_Response *hr = gdu->u.HTTP_Response;
107             std::cout << "Response " << hr->code;
108             std::cout << "<< Respose headers" << std::endl;
109             mp::odr o;
110             //iterate headers
111             for (Z_HTTP_Header *header = hr->headers;
112                     header != 0; 
113                     header = header->next) 
114             {
115                 std::cout << header->name << ": " << header->value << std::endl;
116                 std::string out = test_patterns(vars,
117                         std::string(header->value), 
118                         res_uri_pats, res_groups_bynum); 
119                 if (!out.empty())
120                     header->value = odr_strdup(o, out.c_str());
121             }
122             package.response() = gdu;
123         }
124     };
125
126     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
127
128     /**
129      * Tests pattern from the vector in order and executes recipe on
130        the first match.
131      */
132     const std::string test_patterns(
133             std::map<std::string, std::string> & vars,
134             const std::string & txt, 
135             const spair_vec & uri_pats,
136             const std::vector<std::map<int, std::string> > & groups_bynum_vec)
137         const
138     {
139         for (int i = 0; i < uri_pats.size(); i++) 
140         {
141             std::string out = search_replace(vars, txt, 
142                     uri_pats[i].first, uri_pats[i].second,
143                     groups_bynum_vec[i]);
144             if (!out.empty()) return out;
145         }
146         return "";
147     }
148
149
150     const std::string search_replace(
151             std::map<std::string, std::string> & vars,
152             const std::string & txt,
153             const std::string & uri_re,
154             const std::string & uri_pat,
155             const std::map<int, std::string> & groups_bynum) const
156     {
157         //exec regex against value
158         boost::regex re(uri_re);
159         boost::smatch what;
160         std::string::const_iterator start, end;
161         start = txt.begin();
162         end = txt.end();
163         std::string out;
164         while (regex_search(start, end, what, re)) //find next full match
165         {
166             unsigned i;
167             for (i = 1; i < what.size(); ++i)
168             {
169                 //check if the group is named
170                 std::map<int, std::string>::const_iterator it
171                     = groups_bynum.find(i);
172                 if (it != groups_bynum.end()) 
173                 {   //it is
174                     std::string name = it->second;
175                     if (!what[i].str().empty())
176                         vars[name] = what[i];
177                 }
178
179             }
180             //prepare replacement string
181             std::string rvalue = sub_vars(uri_pat, vars);
182             //rewrite value
183             std::string rhvalue = what.prefix().str() 
184                 + rvalue + what.suffix().str();
185             std::cout << "! Rewritten '"+what.str(0)+"' to '"+rvalue+"'\n";
186             out += rhvalue;
187             start = what[0].second; //move search forward
188         }
189         return out;
190     }
191
192     static void parse_groups(
193             const spair_vec & uri_pats,
194             std::vector<std::map<int, std::string> > & groups_bynum_vec)
195     {
196         for (int h = 0; h < uri_pats.size(); h++) 
197         {
198             int gnum = 0;
199             bool esc = false;
200             //regex is first, subpat is second
201             std::string str = uri_pats[h].first;
202             //for each pair we have an indexing map
203             std::map<int, std::string> groups_bynum;
204             for (int i = 0; i < str.size(); ++i)
205             {
206                 if (!esc && str[i] == '\\')
207                 {
208                     esc = true;
209                     continue;
210                 }
211                 if (!esc && str[i] == '(') //group starts
212                 {
213                     gnum++;
214                     if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
215                     {
216                         i++;
217                         if (i+1 < str.size() && str[i+1] == ':') //non-capturing
218                         {
219                             if (gnum > 0) gnum--;
220                             i++;
221                             continue;
222                         }
223                         if (i+1 < str.size() && str[i+1] == 'P') //optional, python
224                             i++;
225                         if (i+1 < str.size() && str[i+1] == '<') //named
226                         {
227                             i++;
228                             std::string gname;
229                             bool term = false;
230                             while (++i < str.size())
231                             {
232                                 if (str[i] == '>') { term = true; break; }
233                                 if (!isalnum(str[i])) 
234                                     throw mp::filter::FilterException
235                                         ("Only alphanumeric chars allowed, found "
236                                          " in '" 
237                                          + str 
238                                          + "' at " 
239                                          + boost::lexical_cast<std::string>(i)); 
240                                 gname += str[i];
241                             }
242                             if (!term)
243                                 throw mp::filter::FilterException
244                                     ("Unterminated group name '" + gname 
245                                      + " in '" + str +"'");
246                             groups_bynum[gnum] = gname;
247                             std::cout << "Found named group '" << gname 
248                                 << "' at $" << gnum << std::endl;
249                         }
250                     }
251                 }
252                 esc = false;
253             }
254             groups_bynum_vec.push_back(groups_bynum);
255         }
256     }
257
258     static std::string sub_vars (const std::string & in, 
259             const std::map<std::string, std::string> & vars)
260     {
261         std::string out;
262         bool esc = false;
263         for (int i = 0; i < in.size(); ++i)
264         {
265             if (!esc && in[i] == '\\')
266             {
267                 esc = true;
268                 continue;
269             }
270             if (!esc && in[i] == '$') //var
271             {
272                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
273                 {
274                     ++i;
275                     std::string name;
276                     bool term = false;
277                     while (++i < in.size()) 
278                     {
279                         if (in[i] == '}') { term = true; break; }
280                         name += in[i];
281                     }
282                     if (!term) throw mp::filter::FilterException
283                         ("Unterminated var ref in '"+in+"' at "
284                          + boost::lexical_cast<std::string>(i));
285                     std::map<std::string, std::string>::const_iterator it
286                         = vars.find(name);
287                     if (it != vars.end())
288                     {
289                         out += it->second;
290                     }
291                 }
292                 else
293                 {
294                     throw mp::filter::FilterException
295                         ("Malformed or trimmed var ref in '"
296                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
297                 }
298                 continue;
299             }
300             //passthru
301             out += in[i];
302             esc = false;
303         }
304         return out;
305     }
306     
307     void configure(
308             const spair_vec req_uri_pats,
309             const spair_vec res_uri_pats)
310     {
311        //TODO should we really copy them out?
312        this->req_uri_pats = req_uri_pats;
313        this->res_uri_pats = res_uri_pats;
314        //pick up names
315        parse_groups(req_uri_pats, req_groups_bynum);
316        parse_groups(res_uri_pats, res_groups_bynum);
317     };
318
319 private:
320     std::map<std::string, std::string> vars;
321     spair_vec req_uri_pats;
322     spair_vec res_uri_pats;
323     std::vector<std::map<int, std::string> > req_groups_bynum;
324     std::vector<std::map<int, std::string> > res_groups_bynum;
325
326 };
327
328
329 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
330 {
331     try
332     {
333        FilterHeaderRewrite fhr;
334     }
335     catch ( ... ) {
336         BOOST_CHECK (false);
337     }
338 }
339
340 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
341 {
342     try
343     {
344         mp::RouterChain router;
345
346         FilterHeaderRewrite fhr;
347         
348         spair_vec vec_req;
349         vec_req.push_back(std::make_pair(
350         "(?<proto>http\\:\\/\\/s?)(?<pxhost>[^\\/?#]+)\\/(?<pxpath>[^\\/]+)"
351         "\\/(?<host>[^\\/]+)(?<path>.*)",
352         "${proto}${host}${path}"
353         ));
354         vec_req.push_back(std::make_pair(
355         "(?:Host\\: )(.*)",
356         "Host: localhost"
357         ));
358
359         spair_vec vec_res;
360         vec_res.push_back(std::make_pair(
361         "(?<proto>http\\:\\/\\/s?)(?<host>[^\\/?#]+)\\/(?<path>[^ >]+)",
362         "http://${pxhost}/${pxpath}/${host}/${path}"
363         ));
364         
365         fhr.configure(vec_req, vec_res);
366
367         mp::filter::HTTPClient hc;
368         
369         router.append(fhr);
370         router.append(hc);
371
372         // create an http request
373         mp::Package pack;
374
375         mp::odr odr;
376         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
377         "http://proxyhost/proxypath/localhost:80/~jakub/targetsite.php", 0, 1);
378
379         pack.request() = gdu_req;
380
381         //feed to the router
382         pack.router(router).move();
383
384         //analyze the response
385         Z_GDU *gdu_res = pack.response().get();
386         BOOST_CHECK(gdu_res);
387         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
388         
389         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
390         BOOST_CHECK(hres);
391
392     }
393     catch (std::exception & e) {
394         std::cout << e.what();
395         BOOST_CHECK (false);
396     }
397 }
398
399 /*
400  * Local variables:
401  * c-basic-offset: 4
402  * c-file-style: "Stroustrup"
403  * indent-tabs-mode: nil
404  * End:
405  * vim: shiftwidth=4 tabstop=8 expandtab
406  */
407