Rewrite body too
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
39 typedef std::pair<std::string, std::string> string_pair;
40 typedef std::vector<string_pair> spair_vec;
41 typedef spair_vec::iterator spv_iter;
42
43 class FilterHeaderRewrite: public mp::filter::Base {
44 public:
45     void process(mp::Package & package) const 
46     {
47         Z_GDU *gdu = package.request().get();
48         //map of request/response vars
49         std::map<std::string, std::string> vars;
50         //we have an http req
51         if (gdu && gdu->which == Z_GDU_HTTP_Request)
52         {
53             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
54             mp::odr o;
55             std::cout << ">> Request headers" << std::endl;
56             rewrite_reqline(o, hreq, vars);
57             rewrite_headers(o, hreq->headers, vars);
58             rewrite_body(o, &hreq->content_buf, &hreq->content_len, vars);
59             package.request() = gdu;
60         }
61         package.move();
62         gdu = package.response().get();
63         if (gdu && gdu->which == Z_GDU_HTTP_Response)
64         {
65             Z_HTTP_Response *hres = gdu->u.HTTP_Response;
66             std::cout << "Response " << hres->code;
67             std::cout << "<< Respose headers" << std::endl;
68             mp::odr o;
69             rewrite_headers(o, hres->headers, vars);
70             rewrite_body(o, &hres->content_buf, &hres->content_len, vars);
71             package.response() = gdu;
72         }
73     }
74
75     void rewrite_reqline (mp::odr & o, Z_HTTP_Request *hreq,
76             std::map<std::string, std::string> & vars) const 
77     {
78         //rewrite the request line
79         std::string path;
80         if (strstr(hreq->path, "http://") == hreq->path)
81         {
82             std::cout << "Path in the method line is absolute, " 
83                 "possibly a proxy request\n";
84             path += hreq->path;
85         }
86         else
87         {
88             //TODO what about proto
89             path += z_HTTP_header_lookup(hreq->headers, "Host");
90             path += hreq->path; 
91         }
92         std::cout << "Proxy request URL is " << path << std::endl;
93         std::string npath = 
94             test_patterns(vars, path, req_uri_pats, req_groups_bynum);
95         std::cout << "Resp request URL is " << npath << std::endl;
96         if (!npath.empty())
97             hreq->path = odr_strdup(o, npath.c_str());
98     }
99  
100     void rewrite_headers (mp::odr & o, Z_HTTP_Header *headers,
101             std::map<std::string, std::string> & vars) const 
102     {
103         for (Z_HTTP_Header *header = headers;
104                 header != 0; 
105                 header = header->next) 
106         {
107             std::string sheader(header->name);
108             sheader += ": ";
109             sheader += header->value;
110             std::cout << header->name << ": " << header->value << std::endl;
111             std::string out = test_patterns(vars, 
112                     sheader, 
113                     req_uri_pats, req_groups_bynum);
114             if (!out.empty()) 
115             {
116                 size_t pos = out.find(": ");
117                 if (pos == std::string::npos)
118                 {
119                     std::cout << "Header malformed during rewrite, ignoring";
120                     continue;
121                 }
122                 header->name = odr_strdup(o, out.substr(0, pos).c_str());
123                 header->value = odr_strdup(o, out.substr(pos+2, 
124                             std::string::npos).c_str());
125             }
126         }
127     }
128
129     void rewrite_body (mp::odr & o, char **content_buf, int *content_len,
130             std::map<std::string, std::string> & vars) const 
131     {
132         if (*content_buf)
133         {
134             std::string body(*content_buf);
135             std::string nbody = 
136                 test_patterns(vars, body, req_uri_pats, req_groups_bynum);
137             if (!nbody.empty())
138             {
139                 *content_buf = odr_strdup(o, nbody.c_str());
140                 *content_len = nbody.size();
141             }
142         }
143     }
144
145
146     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
147
148     /**
149      * Tests pattern from the vector in order and executes recipe on
150        the first match.
151      */
152     const std::string test_patterns(
153             std::map<std::string, std::string> & vars,
154             const std::string & txt, 
155             const spair_vec & uri_pats,
156             const std::vector<std::map<int, std::string> > & groups_bynum_vec)
157         const
158     {
159         for (int i = 0; i < uri_pats.size(); i++) 
160         {
161             std::string out = search_replace(vars, txt, 
162                     uri_pats[i].first, uri_pats[i].second,
163                     groups_bynum_vec[i]);
164             if (!out.empty()) return out;
165         }
166         return "";
167     }
168
169
170     const std::string search_replace(
171             std::map<std::string, std::string> & vars,
172             const std::string & txt,
173             const std::string & uri_re,
174             const std::string & uri_pat,
175             const std::map<int, std::string> & groups_bynum) const
176     {
177         //exec regex against value
178         boost::regex re(uri_re);
179         boost::smatch what;
180         std::string::const_iterator start, end;
181         start = txt.begin();
182         end = txt.end();
183         std::string out;
184         while (regex_search(start, end, what, re)) //find next full match
185         {
186             unsigned i;
187             for (i = 1; i < what.size(); ++i)
188             {
189                 //check if the group is named
190                 std::map<int, std::string>::const_iterator it
191                     = groups_bynum.find(i);
192                 if (it != groups_bynum.end()) 
193                 {   //it is
194                     std::string name = it->second;
195                     if (!what[i].str().empty())
196                         vars[name] = what[i];
197                 }
198
199             }
200             //prepare replacement string
201             std::string rvalue = sub_vars(uri_pat, vars);
202             //rewrite value
203             std::string rhvalue = what.prefix().str() 
204                 + rvalue + what.suffix().str();
205             std::cout << "! Rewritten '"+what.str(0)+"' to '"+rvalue+"'\n";
206             out += rhvalue;
207             start = what[0].second; //move search forward
208         }
209         return out;
210     }
211
212     static void parse_groups(
213             const spair_vec & uri_pats,
214             std::vector<std::map<int, std::string> > & groups_bynum_vec)
215     {
216         for (int h = 0; h < uri_pats.size(); h++) 
217         {
218             int gnum = 0;
219             bool esc = false;
220             //regex is first, subpat is second
221             std::string str = uri_pats[h].first;
222             //for each pair we have an indexing map
223             std::map<int, std::string> groups_bynum;
224             for (int i = 0; i < str.size(); ++i)
225             {
226                 if (!esc && str[i] == '\\')
227                 {
228                     esc = true;
229                     continue;
230                 }
231                 if (!esc && str[i] == '(') //group starts
232                 {
233                     gnum++;
234                     if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
235                     {
236                         i++;
237                         if (i+1 < str.size() && str[i+1] == ':') //non-capturing
238                         {
239                             if (gnum > 0) gnum--;
240                             i++;
241                             continue;
242                         }
243                         if (i+1 < str.size() && str[i+1] == 'P') //optional, python
244                             i++;
245                         if (i+1 < str.size() && str[i+1] == '<') //named
246                         {
247                             i++;
248                             std::string gname;
249                             bool term = false;
250                             while (++i < str.size())
251                             {
252                                 if (str[i] == '>') { term = true; break; }
253                                 if (!isalnum(str[i])) 
254                                     throw mp::filter::FilterException
255                                         ("Only alphanumeric chars allowed, found "
256                                          " in '" 
257                                          + str 
258                                          + "' at " 
259                                          + boost::lexical_cast<std::string>(i)); 
260                                 gname += str[i];
261                             }
262                             if (!term)
263                                 throw mp::filter::FilterException
264                                     ("Unterminated group name '" + gname 
265                                      + " in '" + str +"'");
266                             groups_bynum[gnum] = gname;
267                             std::cout << "Found named group '" << gname 
268                                 << "' at $" << gnum << std::endl;
269                         }
270                     }
271                 }
272                 esc = false;
273             }
274             groups_bynum_vec.push_back(groups_bynum);
275         }
276     }
277
278     static std::string sub_vars (const std::string & in, 
279             const std::map<std::string, std::string> & vars)
280     {
281         std::string out;
282         bool esc = false;
283         for (int i = 0; i < in.size(); ++i)
284         {
285             if (!esc && in[i] == '\\')
286             {
287                 esc = true;
288                 continue;
289             }
290             if (!esc && in[i] == '$') //var
291             {
292                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
293                 {
294                     ++i;
295                     std::string name;
296                     bool term = false;
297                     while (++i < in.size()) 
298                     {
299                         if (in[i] == '}') { term = true; break; }
300                         name += in[i];
301                     }
302                     if (!term) throw mp::filter::FilterException
303                         ("Unterminated var ref in '"+in+"' at "
304                          + boost::lexical_cast<std::string>(i));
305                     std::map<std::string, std::string>::const_iterator it
306                         = vars.find(name);
307                     if (it != vars.end())
308                     {
309                         out += it->second;
310                     }
311                 }
312                 else
313                 {
314                     throw mp::filter::FilterException
315                         ("Malformed or trimmed var ref in '"
316                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
317                 }
318                 continue;
319             }
320             //passthru
321             out += in[i];
322             esc = false;
323         }
324         return out;
325     }
326     
327     void configure(
328             const spair_vec req_uri_pats,
329             const spair_vec res_uri_pats)
330     {
331        //TODO should we really copy them out?
332        this->req_uri_pats = req_uri_pats;
333        this->res_uri_pats = res_uri_pats;
334        //pick up names
335        parse_groups(req_uri_pats, req_groups_bynum);
336        parse_groups(res_uri_pats, res_groups_bynum);
337     };
338
339 private:
340     spair_vec req_uri_pats;
341     spair_vec res_uri_pats;
342     std::vector<std::map<int, std::string> > req_groups_bynum;
343     std::vector<std::map<int, std::string> > res_groups_bynum;
344
345 };
346
347
348 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
349 {
350     try
351     {
352        FilterHeaderRewrite fhr;
353     }
354     catch ( ... ) {
355         BOOST_CHECK (false);
356     }
357 }
358
359 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
360 {
361     try
362     {
363         mp::RouterChain router;
364
365         FilterHeaderRewrite fhr;
366         
367         spair_vec vec_req;
368         vec_req.push_back(std::make_pair(
369         "(?<proto>http\\:\\/\\/s?)(?<pxhost>[^\\/?#]+)\\/(?<pxpath>[^\\/]+)"
370         "\\/(?<host>[^\\/]+)(?<path>.*)",
371         "${proto}${host}${path}"
372         ));
373         vec_req.push_back(std::make_pair(
374         "(?:Host\\: )(.*)",
375         "Host: localhost"
376         ));
377
378         spair_vec vec_res;
379         vec_res.push_back(std::make_pair(
380         "(?<proto>http\\:\\/\\/s?)(?<host>[^\\/?#]+)\\/(?<path>[^ >]+)",
381         "http://${pxhost}/${pxpath}/${host}/${path}"
382         ));
383         
384         fhr.configure(vec_req, vec_res);
385
386         mp::filter::HTTPClient hc;
387         
388         router.append(fhr);
389         router.append(hc);
390
391         // create an http request
392         mp::Package pack;
393
394         mp::odr odr;
395         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
396         "http://proxyhost/proxypath/localhost:80/~jakub/targetsite.php", 0, 1);
397
398         pack.request() = gdu_req;
399
400         //feed to the router
401         pack.router(router).move();
402
403         //analyze the response
404         Z_GDU *gdu_res = pack.response().get();
405         BOOST_CHECK(gdu_res);
406         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
407         
408         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
409         BOOST_CHECK(hres);
410
411     }
412     catch (std::exception & e) {
413         std::cout << e.what();
414         BOOST_CHECK (false);
415     }
416 }
417
418 /*
419  * Local variables:
420  * c-basic-offset: 4
421  * c-file-style: "Stroustrup"
422  * indent-tabs-mode: nil
423  * End:
424  * vim: shiftwidth=4 tabstop=8 expandtab
425  */
426