1 : // Copyright 2012 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/core/unittest_util.h"
19 : #include "syzygy/pe/decomposer.h"
20 : #include "syzygy/pe/pe_utils.h"
21 : #include "syzygy/pe/unittest_util.h"
22 :
23 : namespace pe {
24 : namespace transforms {
25 :
26 : using block_graph::BlockGraph;
27 : using core::RelativeAddress;
28 : typedef AddImportsTransform::ImportedModule ImportedModule;
29 :
30 : namespace {
31 :
32 : class AddImportsTransformTest : public testing::PELibUnitTest {
33 : public:
34 E : AddImportsTransformTest() : image_layout_(&block_graph_) {
35 E : }
36 :
37 E : virtual void SetUp() {
38 : base::FilePath image_path(
39 E : testing::GetExeRelativePath(testing::kTestDllName));
40 :
41 E : ASSERT_TRUE(pe_file_.Init(image_path));
42 :
43 : // Decompose the test image and look at the result.
44 E : Decomposer decomposer(pe_file_);
45 E : ASSERT_TRUE(decomposer.Decompose(&image_layout_));
46 :
47 : // Retrieve and validate the DOS header.
48 : dos_header_block_ =
49 E : image_layout_.blocks.GetBlockByAddress(RelativeAddress(0));
50 E : ASSERT_TRUE(dos_header_block_ != NULL);
51 E : ASSERT_TRUE(IsValidDosHeaderBlock(dos_header_block_));
52 E : }
53 :
54 : PEFile pe_file_;
55 : BlockGraph block_graph_;
56 : ImageLayout image_layout_;
57 : BlockGraph::Block* dos_header_block_;
58 : };
59 :
60 : // Given an ImportedModule tests that all of its symbols have been properly
61 : // processed.
62 E : void TestSymbols(const ImportedModule& module) {
63 E : for (size_t i = 0; i < module.size(); ++i) {
64 E : BlockGraph::Reference ref;
65 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
66 E : EXPECT_TRUE(ref.referenced() != NULL);
67 E : EXPECT_GE(ref.offset(), 0);
68 : EXPECT_LT(ref.offset(),
69 E : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
70 E : }
71 E : }
72 :
73 : } // namespace
74 :
75 E : TEST_F(AddImportsTransformTest, AddImportsExisting) {
76 E : ImportedModule module("export_dll.dll");
77 : size_t function1 = module.AddSymbol("function1",
78 E : ImportedModule::kAlwaysImport);
79 : size_t function3 = module.AddSymbol("function3",
80 E : ImportedModule::kAlwaysImport);
81 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
82 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
83 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
84 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function1));
85 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function3));
86 :
87 E : AddImportsTransform transform;
88 E : transform.AddModule(&module);
89 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
90 E : &transform, &block_graph_, dos_header_block_));
91 E : EXPECT_EQ(0u, transform.modules_added());
92 E : EXPECT_EQ(0u, transform.symbols_added());
93 :
94 E : EXPECT_TRUE(module.ModuleIsImported());
95 E : EXPECT_TRUE(module.SymbolIsImported(function1));
96 E : EXPECT_TRUE(module.SymbolIsImported(function3));
97 :
98 E : EXPECT_FALSE(module.ModuleWasAdded());
99 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
100 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
101 :
102 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
103 E : module.GetSymbolIatIndex(function1));
104 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
105 E : module.GetSymbolIatIndex(function3));
106 :
107 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
108 E : }
109 :
110 E : TEST_F(AddImportsTransformTest, AddImportsNewSymbol) {
111 E : ImportedModule module("export_dll.dll");
112 : size_t function1 = module.AddSymbol("function1",
113 E : ImportedModule::kAlwaysImport);
114 : size_t function3 = module.AddSymbol("function3",
115 E : ImportedModule::kAlwaysImport);
116 : size_t function4 = module.AddSymbol("function4",
117 E : ImportedModule::kAlwaysImport);
118 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
119 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
120 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
121 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
122 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function1));
123 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function3));
124 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function4));
125 :
126 E : EXPECT_TRUE(module.import_descriptor().block() == NULL);
127 :
128 E : AddImportsTransform transform;
129 E : transform.AddModule(&module);
130 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
131 E : &transform, &block_graph_, dos_header_block_));
132 E : EXPECT_EQ(0u, transform.modules_added());
133 E : EXPECT_EQ(1u, transform.symbols_added());
134 :
135 E : EXPECT_TRUE(module.import_descriptor().block() != NULL);
136 :
137 E : EXPECT_TRUE(module.ModuleIsImported());
138 E : EXPECT_TRUE(module.SymbolIsImported(function1));
139 E : EXPECT_TRUE(module.SymbolIsImported(function3));
140 E : EXPECT_TRUE(module.SymbolIsImported(function4));
141 :
142 E : EXPECT_FALSE(module.ModuleWasAdded());
143 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
144 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
145 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
146 :
147 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
148 E : module.GetSymbolIatIndex(function1));
149 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
150 E : module.GetSymbolIatIndex(function3));
151 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
152 E : module.GetSymbolIatIndex(function4));
153 :
154 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
155 :
156 : // TODO(chrisha): Write the image and try to load it!
157 E : }
158 :
159 E : TEST_F(AddImportsTransformTest, AddImportsNewModule) {
160 E : ImportedModule module("call_trace_client_rpc.dll");
161 : size_t indirect_penter = module.AddSymbol(
162 E : "_indirect_penter", ImportedModule::kAlwaysImport);
163 : size_t indirect_penter_dllmain = module.AddSymbol(
164 E : "_indirect_penter_dllmain", ImportedModule::kAlwaysImport);
165 : EXPECT_EQ("_indirect_penter",
166 E : module.GetSymbolName(indirect_penter));
167 : EXPECT_EQ("_indirect_penter_dllmain",
168 E : module.GetSymbolName(indirect_penter_dllmain));
169 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
170 : EXPECT_EQ(ImportedModule::kAlwaysImport,
171 E : module.GetSymbolMode(indirect_penter));
172 : EXPECT_EQ(ImportedModule::kAlwaysImport,
173 E : module.GetSymbolMode(indirect_penter_dllmain));
174 :
175 E : AddImportsTransform transform;
176 E : transform.AddModule(&module);
177 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
178 E : &transform, &block_graph_, dos_header_block_));
179 E : EXPECT_EQ(1u, transform.modules_added());
180 E : EXPECT_EQ(2u, transform.symbols_added());
181 :
182 E : EXPECT_TRUE(module.ModuleIsImported());
183 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter));
184 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter_dllmain));
185 :
186 E : EXPECT_TRUE(module.ModuleWasAdded());
187 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter));
188 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter_dllmain));
189 :
190 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
191 E : module.GetSymbolIatIndex(indirect_penter));
192 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
193 E : module.GetSymbolIatIndex(indirect_penter_dllmain));
194 :
195 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
196 :
197 : // TODO(chrisha): Write the image and try to load it!
198 E : }
199 :
200 E : TEST_F(AddImportsTransformTest, FindImportsExisting) {
201 E : ImportedModule module("export_dll.dll");
202 E : size_t function1 = module.AddSymbol("function1", ImportedModule::kFindOnly);
203 E : size_t function3 = module.AddSymbol("function3", ImportedModule::kFindOnly);
204 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
205 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
206 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
207 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
208 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
209 :
210 E : AddImportsTransform transform;
211 E : transform.AddModule(&module);
212 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
213 E : &transform, &block_graph_, dos_header_block_));
214 E : EXPECT_EQ(0u, transform.modules_added());
215 E : EXPECT_EQ(0u, transform.symbols_added());
216 :
217 E : EXPECT_TRUE(module.ModuleIsImported());
218 E : EXPECT_TRUE(module.SymbolIsImported(function1));
219 E : EXPECT_TRUE(module.SymbolIsImported(function3));
220 :
221 E : EXPECT_FALSE(module.ModuleWasAdded());
222 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
223 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
224 :
225 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
226 E : module.GetSymbolIatIndex(function1));
227 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
228 E : module.GetSymbolIatIndex(function3));
229 E : }
230 :
231 E : TEST_F(AddImportsTransformTest, FindImportsNewSymbol) {
232 E : ImportedModule module("export_dll.dll");
233 E : size_t function1 = module.AddSymbol("function1", ImportedModule::kFindOnly);
234 E : size_t function3 = module.AddSymbol("function3", ImportedModule::kFindOnly);
235 E : size_t function4 = module.AddSymbol("function4", ImportedModule::kFindOnly);
236 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
237 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
238 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
239 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
240 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
241 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
242 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function4));
243 :
244 E : AddImportsTransform transform;
245 E : transform.AddModule(&module);
246 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
247 E : &transform, &block_graph_, dos_header_block_));
248 E : EXPECT_EQ(0u, transform.modules_added());
249 E : EXPECT_EQ(0u, transform.symbols_added());
250 :
251 E : EXPECT_TRUE(module.ModuleIsImported());
252 E : EXPECT_TRUE(module.SymbolIsImported(function1));
253 E : EXPECT_TRUE(module.SymbolIsImported(function3));
254 E : EXPECT_FALSE(module.SymbolIsImported(function4));
255 :
256 E : EXPECT_FALSE(module.ModuleWasAdded());
257 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
258 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
259 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
260 :
261 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
262 E : module.GetSymbolIatIndex(function1));
263 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
264 E : module.GetSymbolIatIndex(function3));
265 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
266 E : module.GetSymbolIatIndex(function4));
267 E : }
268 :
269 E : TEST_F(AddImportsTransformTest, FindImportsNewModule) {
270 E : ImportedModule module("call_trace_client_rpc.dll");
271 : size_t indirect_penter = module.AddSymbol(
272 E : "_indirect_penter", ImportedModule::kFindOnly);
273 : size_t indirect_penter_dllmain = module.AddSymbol(
274 E : "_indirect_penter_dllmain", ImportedModule::kFindOnly);
275 : EXPECT_EQ("_indirect_penter",
276 E : module.GetSymbolName(indirect_penter));
277 : EXPECT_EQ("_indirect_penter_dllmain",
278 E : module.GetSymbolName(indirect_penter_dllmain));
279 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
280 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(indirect_penter));
281 : EXPECT_EQ(ImportedModule::kFindOnly,
282 E : module.GetSymbolMode(indirect_penter_dllmain));
283 :
284 E : AddImportsTransform transform;
285 E : transform.AddModule(&module);
286 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
287 E : &transform, &block_graph_, dos_header_block_));
288 E : EXPECT_EQ(0u, transform.modules_added());
289 E : EXPECT_EQ(0u, transform.symbols_added());
290 :
291 E : EXPECT_FALSE(module.ModuleIsImported());
292 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter));
293 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter_dllmain));
294 :
295 E : EXPECT_FALSE(module.ModuleWasAdded());
296 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter));
297 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter_dllmain));
298 :
299 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
300 E : module.GetSymbolIatIndex(indirect_penter));
301 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
302 E : module.GetSymbolIatIndex(indirect_penter_dllmain));
303 E : }
304 :
305 : } // namespace transforms
306 : } // namespace pe
|