Only remember match if not empty
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #include <boost/regex.hpp>
29 #include <boost/lexical_cast.hpp>
30
31 #define BOOST_AUTO_TEST_MAIN
32 #define BOOST_TEST_DYN_LINK
33
34 #include <boost/test/auto_unit_test.hpp>
35
36 using namespace boost::unit_test;
37 namespace mp = metaproxy_1;
38
39 class FilterHeaderRewrite: public mp::filter::Base {
40 public:
41     void process(mp::Package & package) const {
42         Z_GDU *gdu = package.request().get();
43         //map of request/response vars
44         std::map<std::string, std::string> vars;
45         //we have an http req
46         if (gdu && gdu->which == Z_GDU_HTTP_Request)
47         {
48             Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
49             mp::odr o;
50             //rewrite the request line
51             std::string path;
52             if (strstr(hreq->path, "http://") == hreq->path)
53             {
54                 std::cout << "Path in the method line is absolute, " 
55                     "possibly a proxy request\n";
56                 path += hreq->path;
57             }
58             else
59             {
60                path += z_HTTP_header_lookup(hreq->headers, "Host");
61                path += hreq->path; 
62             }
63             std::cout << "Proxy request URL is " << path << std::endl;
64             std::string npath = 
65                 search_replace(vars, path, req_uri_rx, req_uri_pat);
66             std::cout << "Resp request URL is " << npath << std::endl;
67             if (!npath.empty())
68                 hreq->path = odr_strdup(o, npath.c_str());
69             std::cout << ">> Request headers" << std::endl;
70             //iterate headers
71             for (Z_HTTP_Header *header = hreq->headers;
72                     header != 0; 
73                     header = header->next) 
74             {
75                 std::cout << header->name << ": " << header->value << std::endl;
76                 std::string out = search_replace(vars, 
77                         std::string(header->value), 
78                         req_uri_rx, req_uri_pat);
79                 if (!out.empty())
80                     header->value = odr_strdup(o, out.c_str());
81             }
82             package.request() = gdu;
83         }
84         package.move();
85         gdu = package.response().get();
86         if (gdu && gdu->which == Z_GDU_HTTP_Response)
87         {
88             Z_HTTP_Response *hr = gdu->u.HTTP_Response;
89             std::cout << "Response " << hr->code;
90             std::cout << "<< Respose headers" << std::endl;
91             mp::odr o;
92             //iterate headers
93             for (Z_HTTP_Header *header = hr->headers;
94                     header != 0; 
95                     header = header->next) 
96             {
97                 std::cout << header->name << ": " << header->value << std::endl;
98                 std::string out = search_replace(vars, 
99                         std::string(header->value), 
100                         resp_uri_rx, resp_uri_pat);
101                 if (!out.empty())
102                     header->value = odr_strdup(o, out.c_str());
103             }
104             package.response() = gdu;
105         }
106     };
107
108     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
109
110     const std::string search_replace(
111             std::map<std::string, std::string> & vars,
112             const std::string txt,
113             const std::string & uri_re,
114             const std::string & uri_pat) const
115     {
116         //exec regex against value
117         boost::regex re(uri_re);
118         boost::smatch what;
119         std::string::const_iterator start, end;
120         start = txt.begin();
121         end = txt.end();
122         std::string out;
123         while (regex_search(start, end, what, re)) //find next full match
124         {
125             unsigned i;
126             for (i = 1; i < what.size(); ++i)
127             {
128                 //check if the group is named
129                 std::map<int, std::string>::const_iterator it
130                     = groups_by_num.find(i);
131                 if (it != groups_by_num.end()) 
132                 {   //it is
133                     std::string name = it->second;
134                     if (!what[i].str().empty())
135                         vars[name] = what[i];
136                 }
137
138             }
139             //prepare replacement string
140             std::string rvalue = sub_vars(uri_pat, vars);
141             //rewrite value
142             std::string rhvalue = what.prefix().str() 
143                 + rvalue + what.suffix().str();
144             std::cout << "! Rewritten '"+what.str(0)+"' to '"+rvalue+"'\n";
145             out += rhvalue;
146             start = what[0].second; //move search forward
147         }
148         return out;
149     };
150
151     static void parse_groups(const std::string & str,
152             std::map<int, std::string> & groups_bynum,
153             std::map<std::string, int> & groups_byname)
154     {
155        int gnum = 0;
156        bool esc = false;
157        for (int i = 0; i < str.size(); ++i)
158        {
159            if (!esc && str[i] == '\\')
160            {
161                esc = true;
162                continue;
163            }
164            if (!esc && str[i] == '(') //group starts
165            {
166                gnum++;
167                if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
168                {
169                    i++;
170                    if (i+1 < str.size() && str[i+1] == ':') //non-capturing
171                    {
172                        if (gnum > 0) gnum--;
173                        i++;
174                        continue;
175                    }
176                    if (i+1 < str.size() && str[i+1] == 'P') //optional, python
177                        i++;
178                    if (i+1 < str.size() && str[i+1] == '<') //named
179                    {
180                        i++;
181                        std::string gname;
182                        bool term = false;
183                        while (++i < str.size())
184                        {
185                            if (str[i] == '>') { term = true; break; }
186                            if (!isalnum(str[i])) 
187                                throw mp::filter::FilterException
188                                    ("Only alphanumeric chars allowed, found "
189                                     " in '" 
190                                     + str 
191                                     + "' at " 
192                                     + boost::lexical_cast<std::string>(i)); 
193                            gname += str[i];
194                        }
195                        if (!term)
196                            throw mp::filter::FilterException
197                                ("Unterminated group name '" + gname 
198                                 + " in '" + str +"'");
199                       groups_bynum[gnum] = gname;
200                       groups_byname[gname] = gnum;
201                       std::cout << "Found named group '" << gname 
202                           << "' at $" << gnum << std::endl;
203                    }
204                }
205            }
206            esc = false;
207        }
208     }
209
210     static std::string sub_vars (const std::string & in, 
211             const std::map<std::string, std::string> & vars)
212     {
213         std::string out;
214         bool esc = false;
215         for (int i = 0; i < in.size(); ++i)
216         {
217             if (!esc && in[i] == '\\')
218             {
219                 esc = true;
220                 continue;
221             }
222             if (!esc && in[i] == '$') //var
223             {
224                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
225                 {
226                     ++i;
227                     std::string name;
228                     bool term = false;
229                     while (++i < in.size()) 
230                     {
231                         if (in[i] == '}') { term = true; break; }
232                         name += in[i];
233                     }
234                     if (!term) throw mp::filter::FilterException
235                         ("Unterminated var ref in '"+in+"' at "
236                          + boost::lexical_cast<std::string>(i));
237                     std::map<std::string, std::string>::const_iterator it
238                         = vars.find(name);
239                     if (it != vars.end())
240                     {
241                         out += it->second;
242                     }
243                 }
244                 else
245                 {
246                     throw mp::filter::FilterException
247                         ("Malformed or trimmed var ref in '"
248                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
249                 }
250                 continue;
251             }
252             //passthru
253             out += in[i];
254             esc = false;
255         }
256         return out;
257     }
258     
259     void configure(
260             const std::string & req_uri_rx, 
261             const std::string & req_uri_pat, 
262             const std::string & resp_uri_rx, 
263             const std::string & resp_uri_pat) 
264     {
265        this->req_uri_rx = req_uri_rx;
266        this->req_uri_pat = req_uri_pat;
267        //pick up names
268        parse_groups(req_uri_rx, groups_by_num, groups_by_name);
269        this->resp_uri_rx = resp_uri_rx;
270        this->resp_uri_pat = resp_uri_pat;
271     };
272
273 private:
274     std::map<std::string, std::string> vars;
275     std::string req_uri_rx;
276     std::string resp_uri_rx;
277     std::string req_uri_pat;
278     std::string resp_uri_pat;
279     std::map<int, std::string> groups_by_num;
280     std::map<std::string, int> groups_by_name;
281
282 };
283
284
285 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
286 {
287     try
288     {
289        FilterHeaderRewrite fhr;
290     }
291     catch ( ... ) {
292         BOOST_CHECK (false);
293     }
294 }
295
296 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
297 {
298     try
299     {
300         mp::RouterChain router;
301
302         FilterHeaderRewrite fhr;
303         fhr.configure(
304     "((?<proto>http\\:\\/\\/s?)(?<pxhost>[^\\/?#]+)\\/(?<pxpath>[^\\/]+)"
305     "(?<target>.+))|(proxyhost)",
306                 "${proto}${target}${whatever}",
307                 //rewrite connection close
308                 "close",
309                 "open for ${host}");
310
311         mp::filter::HTTPClient hc;
312         
313         router.append(fhr);
314         router.append(hc);
315
316         // create an http request
317         mp::Package pack;
318
319         mp::odr odr;
320         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
321         "http://proxyhost/proxypath/localhost:80/~jakub/targetsite.php", 0, 1);
322
323         pack.request() = gdu_req;
324
325         //feed to the router
326         pack.router(router).move();
327
328         //analyze the response
329         Z_GDU *gdu_res = pack.response().get();
330         BOOST_CHECK(gdu_res);
331         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
332         
333         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
334         BOOST_CHECK(hres);
335
336     }
337     catch (std::exception & e) {
338         std::cout << e.what();
339         BOOST_CHECK (false);
340     }
341 }
342
343 /*
344  * Local variables:
345  * c-basic-offset: 4
346  * c-file-style: "Stroustrup"
347  * indent-tabs-mode: nil
348  * End:
349  * vim: shiftwidth=4 tabstop=8 expandtab
350  */
351