1 : // Copyright 2012 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/pe_add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/block_graph/unittest_util.h"
19 : #include "syzygy/core/unittest_util.h"
20 : #include "syzygy/pe/decomposer.h"
21 : #include "syzygy/pe/pe_utils.h"
22 : #include "syzygy/pe/unittest_util.h"
23 :
24 : namespace pe {
25 : namespace transforms {
26 :
27 : using block_graph::BlockGraph;
28 : using core::RelativeAddress;
29 :
30 : namespace {
31 :
32 : class PEAddImportsTransformTest : public testing::PELibUnitTest {
33 : public:
34 : PEAddImportsTransformTest()
35 : : image_layout_(&block_graph_),
36 E : dos_header_block_(NULL) {
37 E : }
38 :
39 E : virtual void SetUp() {
40 : base::FilePath image_path(
41 E : testing::GetExeRelativePath(testing::kTestDllName));
42 :
43 E : ASSERT_TRUE(pe_file_.Init(image_path));
44 :
45 : // Decompose the test image and look at the result.
46 E : Decomposer decomposer(pe_file_);
47 E : ASSERT_TRUE(decomposer.Decompose(&image_layout_));
48 :
49 : // Retrieve and validate the DOS header.
50 : dos_header_block_ =
51 E : image_layout_.blocks.GetBlockByAddress(RelativeAddress(0));
52 E : ASSERT_TRUE(dos_header_block_ != NULL);
53 E : ASSERT_TRUE(IsValidDosHeaderBlock(dos_header_block_));
54 E : }
55 :
56 : PEFile pe_file_;
57 : testing::DummyTransformPolicy policy_;
58 : BlockGraph block_graph_;
59 : ImageLayout image_layout_;
60 : BlockGraph::Block* dos_header_block_;
61 : };
62 :
63 : // Given an ImportedModule tests that all of its symbols have been properly
64 : // processed.
65 E : void TestSymbols(const ImportedModule& module) {
66 E : for (size_t i = 0; i < module.size(); ++i) {
67 E : BlockGraph::Reference ref;
68 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
69 E : EXPECT_TRUE(ref.referenced() != NULL);
70 E : EXPECT_GE(ref.offset(), 0);
71 : EXPECT_LT(ref.offset(),
72 E : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
73 E : }
74 E : }
75 :
76 : } // namespace
77 :
78 E : TEST_F(PEAddImportsTransformTest, AddImportsExisting) {
79 E : ImportedModule module("export_dll.dll");
80 : size_t function1 = module.AddSymbol("function1",
81 E : ImportedModule::kAlwaysImport);
82 : size_t function3 = module.AddSymbol("function3",
83 E : ImportedModule::kAlwaysImport);
84 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
85 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
86 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
87 : EXPECT_EQ(ImportedModule::kAlwaysImport,
88 E : module.GetSymbolMode(function1));
89 : EXPECT_EQ(ImportedModule::kAlwaysImport,
90 E : module.GetSymbolMode(function3));
91 :
92 E : PEAddImportsTransform transform;
93 E : transform.AddModule(&module);
94 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
95 E : &transform, &policy_, &block_graph_, dos_header_block_));
96 E : EXPECT_EQ(0u, transform.modules_added());
97 E : EXPECT_EQ(0u, transform.symbols_added());
98 :
99 E : EXPECT_TRUE(module.ModuleIsImported());
100 E : EXPECT_TRUE(module.SymbolIsImported(function1));
101 E : EXPECT_TRUE(module.SymbolIsImported(function3));
102 :
103 E : EXPECT_FALSE(module.ModuleWasAdded());
104 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
105 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
106 :
107 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
108 E : module.GetSymbolImportIndex(function1));
109 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
110 E : module.GetSymbolImportIndex(function3));
111 :
112 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
113 E : }
114 :
115 E : TEST_F(PEAddImportsTransformTest, AddImportsNewSymbol) {
116 E : ImportedModule module("export_dll.dll");
117 : size_t function1 = module.AddSymbol("function1",
118 E : ImportedModule::kAlwaysImport);
119 : size_t function3 = module.AddSymbol("function3",
120 E : ImportedModule::kAlwaysImport);
121 : size_t function4 = module.AddSymbol("function4",
122 E : ImportedModule::kAlwaysImport);
123 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
124 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
125 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
126 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
127 : EXPECT_EQ(ImportedModule::kAlwaysImport,
128 E : module.GetSymbolMode(function1));
129 : EXPECT_EQ(ImportedModule::kAlwaysImport,
130 E : module.GetSymbolMode(function3));
131 : EXPECT_EQ(ImportedModule::kAlwaysImport,
132 E : module.GetSymbolMode(function4));
133 :
134 E : PEAddImportsTransform transform;
135 E : transform.AddModule(&module);
136 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
137 E : &transform, &policy_, &block_graph_, dos_header_block_));
138 E : EXPECT_EQ(0u, transform.modules_added());
139 E : EXPECT_EQ(1u, transform.symbols_added());
140 :
141 E : EXPECT_TRUE(module.ModuleIsImported());
142 E : EXPECT_TRUE(module.SymbolIsImported(function1));
143 E : EXPECT_TRUE(module.SymbolIsImported(function3));
144 E : EXPECT_TRUE(module.SymbolIsImported(function4));
145 :
146 E : EXPECT_FALSE(module.ModuleWasAdded());
147 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
148 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
149 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
150 :
151 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
152 E : module.GetSymbolImportIndex(function1));
153 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
154 E : module.GetSymbolImportIndex(function3));
155 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
156 E : module.GetSymbolImportIndex(function4));
157 :
158 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
159 :
160 : // TODO(chrisha): Write the image and try to load it!
161 E : }
162 :
163 E : TEST_F(PEAddImportsTransformTest, AddImportsNewModule) {
164 E : ImportedModule module("call_trace_client_rpc.dll");
165 : size_t indirect_penter = module.AddSymbol(
166 E : "_indirect_penter", ImportedModule::kAlwaysImport);
167 : size_t indirect_penter_dllmain = module.AddSymbol(
168 E : "_indirect_penter_dllmain", ImportedModule::kAlwaysImport);
169 : EXPECT_EQ("_indirect_penter",
170 E : module.GetSymbolName(indirect_penter));
171 : EXPECT_EQ("_indirect_penter_dllmain",
172 E : module.GetSymbolName(indirect_penter_dllmain));
173 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
174 : EXPECT_EQ(ImportedModule::kAlwaysImport,
175 E : module.GetSymbolMode(indirect_penter));
176 : EXPECT_EQ(ImportedModule::kAlwaysImport,
177 E : module.GetSymbolMode(indirect_penter_dllmain));
178 :
179 E : PEAddImportsTransform transform;
180 E : transform.AddModule(&module);
181 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
182 E : &transform, &policy_, &block_graph_, dos_header_block_));
183 E : EXPECT_EQ(1u, transform.modules_added());
184 E : EXPECT_EQ(2u, transform.symbols_added());
185 :
186 E : EXPECT_TRUE(module.ModuleIsImported());
187 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter));
188 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter_dllmain));
189 :
190 E : EXPECT_TRUE(module.ModuleWasAdded());
191 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter));
192 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter_dllmain));
193 :
194 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
195 E : module.GetSymbolImportIndex(indirect_penter));
196 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
197 E : module.GetSymbolImportIndex(indirect_penter_dllmain));
198 :
199 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
200 :
201 : // TODO(chrisha): Write the image and try to load it!
202 E : }
203 :
204 E : TEST_F(PEAddImportsTransformTest, FindImportsExisting) {
205 E : ImportedModule module("export_dll.dll");
206 : size_t function1 = module.AddSymbol("function1",
207 E : ImportedModule::kFindOnly);
208 : size_t function3 = module.AddSymbol("function3",
209 E : ImportedModule::kFindOnly);
210 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
211 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
212 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
213 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
214 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
215 :
216 E : PEAddImportsTransform transform;
217 E : transform.AddModule(&module);
218 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
219 E : &transform, &policy_, &block_graph_, dos_header_block_));
220 E : EXPECT_EQ(0u, transform.modules_added());
221 E : EXPECT_EQ(0u, transform.symbols_added());
222 :
223 E : EXPECT_TRUE(module.ModuleIsImported());
224 E : EXPECT_TRUE(module.SymbolIsImported(function1));
225 E : EXPECT_TRUE(module.SymbolIsImported(function3));
226 :
227 E : EXPECT_FALSE(module.ModuleWasAdded());
228 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
229 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
230 :
231 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
232 E : module.GetSymbolImportIndex(function1));
233 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
234 E : module.GetSymbolImportIndex(function3));
235 E : }
236 :
237 E : TEST_F(PEAddImportsTransformTest, FindImportsNewSymbol) {
238 E : ImportedModule module("export_dll.dll");
239 : size_t function1 = module.AddSymbol("function1",
240 E : ImportedModule::kFindOnly);
241 : size_t function3 = module.AddSymbol("function3",
242 E : ImportedModule::kFindOnly);
243 : size_t function4 = module.AddSymbol("function4",
244 E : ImportedModule::kFindOnly);
245 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
246 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
247 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
248 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
249 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
250 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
251 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function4));
252 :
253 E : PEAddImportsTransform transform;
254 E : transform.AddModule(&module);
255 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
256 E : &transform, &policy_, &block_graph_, dos_header_block_));
257 E : EXPECT_EQ(0u, transform.modules_added());
258 E : EXPECT_EQ(0u, transform.symbols_added());
259 :
260 E : EXPECT_TRUE(module.ModuleIsImported());
261 E : EXPECT_TRUE(module.SymbolIsImported(function1));
262 E : EXPECT_TRUE(module.SymbolIsImported(function3));
263 E : EXPECT_FALSE(module.SymbolIsImported(function4));
264 :
265 E : EXPECT_FALSE(module.ModuleWasAdded());
266 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
267 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
268 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
269 :
270 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
271 E : module.GetSymbolImportIndex(function1));
272 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
273 E : module.GetSymbolImportIndex(function3));
274 : EXPECT_EQ(ImportedModule::kInvalidImportIndex,
275 E : module.GetSymbolImportIndex(function4));
276 E : }
277 :
278 E : TEST_F(PEAddImportsTransformTest, FindImportsNewModule) {
279 E : ImportedModule module("call_trace_client_rpc.dll");
280 : size_t indirect_penter = module.AddSymbol(
281 E : "_indirect_penter", ImportedModule::kFindOnly);
282 : size_t indirect_penter_dllmain = module.AddSymbol(
283 E : "_indirect_penter_dllmain", ImportedModule::kFindOnly);
284 : EXPECT_EQ("_indirect_penter",
285 E : module.GetSymbolName(indirect_penter));
286 : EXPECT_EQ("_indirect_penter_dllmain",
287 E : module.GetSymbolName(indirect_penter_dllmain));
288 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
289 : EXPECT_EQ(ImportedModule::kFindOnly,
290 E : module.GetSymbolMode(indirect_penter));
291 : EXPECT_EQ(ImportedModule::kFindOnly,
292 E : module.GetSymbolMode(indirect_penter_dllmain));
293 :
294 E : PEAddImportsTransform transform;
295 E : transform.AddModule(&module);
296 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
297 E : &transform, &policy_, &block_graph_, dos_header_block_));
298 E : EXPECT_EQ(0u, transform.modules_added());
299 E : EXPECT_EQ(0u, transform.symbols_added());
300 :
301 E : EXPECT_FALSE(module.ModuleIsImported());
302 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter));
303 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter_dllmain));
304 :
305 E : EXPECT_FALSE(module.ModuleWasAdded());
306 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter));
307 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter_dllmain));
308 :
309 : EXPECT_EQ(ImportedModule::kInvalidImportIndex,
310 E : module.GetSymbolImportIndex(indirect_penter));
311 : EXPECT_EQ(ImportedModule::kInvalidImportIndex,
312 E : module.GetSymbolImportIndex(indirect_penter_dllmain));
313 E : }
314 :
315 : } // namespace transforms
316 : } // namespace pe
|