Functional request header rewriting
[metaproxy-moved-to-github.git] / src / test_filter_rewrite.cpp
1 /* This file is part of Metaproxy.
2    Copyright (C) 2005-2013 Index Data
3
4 Metaproxy is free software; you can redistribute it and/or modify it under
5 the terms of the GNU General Public License as published by the Free
6 Software Foundation; either version 2, or (at your option) any later
7 version.
8
9 Metaproxy is distributed in the hope that it will be useful, but WITHOUT ANY
10 WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17 */
18
19 #include "config.hpp"
20 #include <iostream>
21 #include <stdexcept>
22
23 #include "filter_http_client.hpp"
24 #include <metaproxy/util.hpp>
25 #include "router_chain.hpp"
26 #include <metaproxy/package.hpp>
27
28 #define BOOST_REGEX_MATCH_EXTRA
29
30 #include <boost/xpressive/xpressive.hpp>
31 #include <boost/lexical_cast.hpp>
32
33 #define BOOST_AUTO_TEST_MAIN
34 #define BOOST_TEST_DYN_LINK
35
36 #include <boost/test/auto_unit_test.hpp>
37
38 using namespace boost::unit_test;
39 using namespace boost::xpressive;
40 namespace mp = metaproxy_1;
41
42 class FilterHeaderRewrite: public mp::filter::Base {
43 public:
44     void process(mp::Package & package) const {
45         Z_GDU *gdu = package.request().get();
46         //we have an http req
47         if (gdu && gdu->which == Z_GDU_HTTP_Request)
48         {
49           std::cout << "Request headers" << std::endl;
50           Z_HTTP_Request *hreq = gdu->u.HTTP_Request;
51           //dump req headers
52           for (Z_HTTP_Header *header = hreq->headers;
53               header != 0; 
54               header = header->next) 
55           {
56             std::cout << header->name << ": " << header->value << std::endl;
57             rewrite_req_header(header);
58           }
59         }
60         package.move();
61         gdu = package.response().get();
62         if (gdu && gdu->which == Z_GDU_HTTP_Response)
63         {
64           std::cout << "Respose headers" << std::endl;
65           Z_HTTP_Response *hr = gdu->u.HTTP_Response;
66           //dump resp headers
67           for (Z_HTTP_Header *header = hr->headers;
68               header != 0; 
69               header = header->next) 
70           {
71             std::cout << header->name << ": " << header->value << std::endl;
72           }
73         }
74
75     };
76     void configure(const xmlNode* ptr, bool test_only, const char *path) {};
77
78     void rewrite_req_header(Z_HTTP_Header *header) const
79     {
80         //exec regex against value
81         sregex re = sregex::compile(req_uri_rx);
82         smatch what;
83         std::string hvalue(header->value);
84         std::map<std::string, std::string> vars;
85         if (regex_match(hvalue, what, re))
86         {
87             unsigned i;
88             for (i = 1; i < what.size(); ++i)
89             {
90                 //check if the group is named
91                 std::map<int, std::string>::const_iterator it
92                     = groups_by_num.find(i);
93                 if (it != groups_by_num.end()) 
94                 {   //it is
95                     std::string name = it->second;
96                     vars[name] = what[i];
97                 }
98             }
99             //rewrite the header according to the recipe
100             std::string rvalue = sub_vars(req_uri_pat, vars);
101             std::cout << "Rewritten '"+hvalue+"' to '"+rvalue+"'\n";
102         }
103         else
104         {
105             std::cout << "No match found in '" + hvalue + "'\n";
106         }
107     };
108
109     static void parse_groups(const std::string & str,
110             std::map<int, std::string> & groups_bynum,
111             std::map<std::string, int> & groups_byname)
112     {
113        int gnum = 0;
114        bool esc = false;
115        for (int i = 0; i < str.size(); ++i)
116        {
117            if (!esc && str[i] == '\\')
118            {
119                esc = true;
120                continue;
121            }
122            if (!esc && str[i] == '(') //group starts
123            {
124                gnum++;
125                if (i+1 < str.size() && str[i+1] == '?') //group with attrs 
126                {
127                    i++;
128                    if (i+1 < str.size() && str[i+1] == 'P') //optional, python
129                        i++;
130                    if (i+1 < str.size() && str[i+1] == '<') //named
131                    {
132                        i++;
133                        std::string gname;
134                        bool term = false;
135                        while (++i < str.size())
136                        {
137                            if (str[i] == '>') { term = true; break; }
138                            if (!isalnum(str[i])) 
139                                throw mp::filter::FilterException
140                                    ("Only alphanumeric chars allowed, found "
141                                     " in '" 
142                                     + str 
143                                     + "' at " 
144                                     + boost::lexical_cast<std::string>(i)); 
145                            gname += str[i];
146                        }
147                        if (!term)
148                            throw mp::filter::FilterException
149                                ("Unterminated group name '" + gname 
150                                 + " in '" + str +"'");
151                       groups_bynum[gnum] = gname;
152                       groups_byname[gname] = gnum;
153                       std::cout << "Found named group '" << gname 
154                           << "' at $" << gnum << std::endl;
155                    }
156                }
157            }
158            esc = false;
159        }
160     }
161
162     static std::string sub_vars (const std::string & in, 
163             const std::map<std::string, std::string> & vars)
164     {
165         std::string out;
166         bool esc = false;
167         for (int i = 0; i < in.size(); ++i)
168         {
169             if (!esc && in[i] == '\\')
170             {
171                 esc = true;
172                 continue;
173             }
174             if (!esc && in[i] == '$') //var
175             {
176                 if (i+1 < in.size() && in[i+1] == '{') //ref prefix
177                 {
178                     ++i;
179                     std::string name;
180                     bool term = false;
181                     while (++i < in.size()) 
182                     {
183                         if (in[i] == '}') { term = true; break; }
184                         name += in[i];
185                     }
186                     if (!term) throw mp::filter::FilterException
187                         ("Unterminated var ref in '"+in+"' at "
188                          + boost::lexical_cast<std::string>(i));
189                     std::map<std::string, std::string>::const_iterator it
190                         = vars.find(name);
191                     if (it != vars.end())
192                         out += it->second;
193                 }
194                 else
195                 {
196                     throw mp::filter::FilterException
197                         ("Malformed or trimmed var ref in '"
198                          +in+"' at "+boost::lexical_cast<std::string>(i)); 
199                 }
200                 continue;
201             }
202             //passthru
203             out += in[i];
204             esc = false;
205         }
206         return out;
207     }
208     
209     void configure(
210             const std::string & req_uri_rx, 
211             const std::string & req_uri_pat, 
212             const std::string & resp_uri_rx, 
213             const std::string & resp_uri_pat) 
214     {
215        this->req_uri_rx = req_uri_rx;
216        this->req_uri_pat = req_uri_pat;
217        //pick up names
218        parse_groups(req_uri_rx, groups_by_num, groups_by_name);
219        this->resp_uri_rx = resp_uri_rx;
220        this->resp_uri_pat = resp_uri_pat;
221     };
222
223 private:
224     std::map<std::string, std::string> vars;
225     std::string req_uri_rx;
226     std::string resp_uri_rx;
227     std::string req_uri_pat;
228     std::string resp_uri_pat;
229     std::map<int, std::string> groups_by_num;
230     std::map<std::string, int> groups_by_name;
231
232 };
233
234
235 BOOST_AUTO_TEST_CASE( test_filter_rewrite_1 )
236 {
237     try
238     {
239        FilterHeaderRewrite fhr;
240     }
241     catch ( ... ) {
242         BOOST_CHECK (false);
243     }
244 }
245
246 BOOST_AUTO_TEST_CASE( test_filter_rewrite_2 )
247 {
248     try
249     {
250         mp::RouterChain router;
251
252         FilterHeaderRewrite fhr;
253         fhr.configure(
254                 ".*?(?P<host>[^:]+):(?P<port>\\d+).*",
255                 "http://${host}:${port}/somepath",
256                 ".*(localhost).*",
257                 "http:://g");
258         mp::filter::HTTPClient hc;
259         
260         router.append(fhr);
261         router.append(hc);
262
263         // create an http request
264         mp::Package pack;
265
266         mp::odr odr;
267         Z_GDU *gdu_req = z_get_HTTP_Request_uri(odr, 
268             "http://localhost:80/~jakub/targetsite.php", 0, 1);
269
270         pack.request() = gdu_req;
271
272         //feed to the router
273         pack.router(router).move();
274
275         //analyze the response
276         Z_GDU *gdu_res = pack.response().get();
277         BOOST_CHECK(gdu_res);
278         BOOST_CHECK_EQUAL(gdu_res->which, Z_GDU_HTTP_Response);
279         
280         Z_HTTP_Response *hres = gdu_res->u.HTTP_Response;
281         BOOST_CHECK(hres);
282
283     }
284     catch (std::exception & e) {
285         std::cout << e.what();
286         BOOST_CHECK (false);
287     }
288 }
289
290 /*
291  * Local variables:
292  * c-basic-offset: 4
293  * c-file-style: "Stroustrup"
294  * indent-tabs-mode: nil
295  * End:
296  * vim: shiftwidth=4 tabstop=8 expandtab
297  */
298