1 : // Copyright 2013 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/coff_add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/block_graph/unittest_util.h"
19 : #include "syzygy/pe/pe_utils.h"
20 : #include "syzygy/pe/unittest_util.h"
21 :
22 : namespace pe {
23 : namespace transforms {
24 :
25 : using block_graph::BlockGraph;
26 :
27 : namespace {
28 :
29 : class CoffAddImportsTransformTest : public testing::CoffUnitTest {
30 : public:
31 E : virtual void SetUp() OVERRIDE {
32 E : testing::CoffUnitTest::SetUp();
33 E : }
34 :
35 : // Check that symbols in @p module have been assigned a reference, and that
36 : // they pass through a round-trip writing and decomposition.
37 E : void TestSymbols(const ImportedModule& module) {
38 : // Check resulting references.
39 E : for (size_t i = 0; i < module.size(); ++i) {
40 E : BlockGraph::Reference ref;
41 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
42 E : EXPECT_TRUE(ref.referenced() != NULL);
43 E : EXPECT_GE(ref.offset(), 0);
44 E : EXPECT_LT(ref.offset(),
45 : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
46 E : }
47 :
48 E : ASSERT_NO_FATAL_FAILURE(TestRoundTrip());
49 E : }
50 : };
51 :
// Decorated symbol names used by the tests below. kFunction1Name carries the
// "__imp_" prefix; kFunction3Name and kFunction4Name do not.
const char kFunction1Name[] = "__imp_?function1@@YAHXZ";
const char kFunction3Name[] = "?function3@@YAHXZ";
const char kFunction4Name[] = "?function4@@YAHXZ";
// Named after the symbol it holds. Not referenced by any test in this file.
const char kMemset[] = "_memset";  // Multiply defined.
56 :
57 : } // namespace
58 :
59 E : TEST_F(CoffAddImportsTransformTest, AddImportsExisting) {
60 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
61 E : ImportedModule module("export_dll.dll");
62 : size_t function1 = module.AddSymbol(kFunction1Name,
63 E : ImportedModule::kAlwaysImport);
64 : size_t function3 = module.AddSymbol(kFunction3Name,
65 E : ImportedModule::kAlwaysImport);
66 :
67 E : CoffAddImportsTransform transform;
68 E : transform.AddModule(&module);
69 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
70 E : &transform, &policy_, &block_graph_, headers_block_));
71 E : EXPECT_EQ(0u, transform.modules_added());
72 E : EXPECT_EQ(0u, transform.symbols_added());
73 :
74 E : EXPECT_TRUE(module.ModuleIsImported());
75 E : EXPECT_TRUE(module.SymbolIsImported(function1));
76 E : EXPECT_TRUE(module.SymbolIsImported(function3));
77 :
78 E : EXPECT_FALSE(module.ModuleWasAdded());
79 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
80 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
81 :
82 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
83 E : }
84 :
85 E : TEST_F(CoffAddImportsTransformTest, AddImportsNewSymbol) {
86 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
87 E : ImportedModule module("export_dll.dll");
88 : size_t function1 = module.AddSymbol(kFunction1Name,
89 E : ImportedModule::kAlwaysImport);
90 : size_t function3 = module.AddSymbol(kFunction3Name,
91 E : ImportedModule::kAlwaysImport);
92 : size_t function4 = module.AddSymbol(kFunction4Name,
93 E : ImportedModule::kAlwaysImport);
94 :
95 E : CoffAddImportsTransform transform;
96 E : transform.AddModule(&module);
97 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
98 E : &transform, &policy_, &block_graph_, headers_block_));
99 E : EXPECT_EQ(0u, transform.modules_added());
100 E : EXPECT_EQ(1u, transform.symbols_added());
101 :
102 E : EXPECT_TRUE(module.ModuleIsImported());
103 E : EXPECT_TRUE(module.SymbolIsImported(function1));
104 E : EXPECT_TRUE(module.SymbolIsImported(function3));
105 E : EXPECT_TRUE(module.SymbolIsImported(function4));
106 :
107 E : EXPECT_FALSE(module.ModuleWasAdded());
108 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
109 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
110 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
111 :
112 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
113 E : }
114 :
115 E : TEST_F(CoffAddImportsTransformTest, FindImportsExistingMultiple) {
116 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
117 E : ImportedModule module("export_dll.dll");
118 : size_t function1 = module.AddSymbol(kFunction1Name,
119 E : ImportedModule::kFindOnly);
120 : size_t function3 = module.AddSymbol(kFunction3Name,
121 E : ImportedModule::kFindOnly);
122 :
123 E : CoffAddImportsTransform transform;
124 E : transform.AddModule(&module);
125 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
126 E : &transform, &policy_, &block_graph_, headers_block_));
127 E : EXPECT_EQ(0u, transform.modules_added());
128 E : EXPECT_EQ(0u, transform.symbols_added());
129 :
130 E : EXPECT_TRUE(module.ModuleIsImported());
131 E : EXPECT_TRUE(module.SymbolIsImported(function1));
132 E : EXPECT_TRUE(module.SymbolIsImported(function3));
133 :
134 E : EXPECT_FALSE(module.ModuleWasAdded());
135 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
136 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
137 E : }
138 :
139 E : TEST_F(CoffAddImportsTransformTest, FindImportsNewSymbol) {
140 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
141 E : ImportedModule module("export_dll.dll");
142 : size_t function1 = module.AddSymbol(kFunction1Name,
143 E : ImportedModule::kFindOnly);
144 : size_t function3 = module.AddSymbol(kFunction3Name,
145 E : ImportedModule::kFindOnly);
146 : size_t function4 = module.AddSymbol(kFunction4Name,
147 E : ImportedModule::kFindOnly);
148 :
149 E : CoffAddImportsTransform transform;
150 E : transform.AddModule(&module);
151 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
152 E : &transform, &policy_, &block_graph_, headers_block_));
153 E : EXPECT_EQ(0u, transform.modules_added());
154 E : EXPECT_EQ(0u, transform.symbols_added());
155 :
156 E : EXPECT_TRUE(module.ModuleIsImported());
157 E : EXPECT_TRUE(module.SymbolIsImported(function1));
158 E : EXPECT_TRUE(module.SymbolIsImported(function3));
159 E : EXPECT_FALSE(module.SymbolIsImported(function4));
160 :
161 E : EXPECT_FALSE(module.ModuleWasAdded());
162 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
163 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
164 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
165 E : }
166 :
167 E : TEST_F(CoffAddImportsTransformTest, EmptyStringTable) {
168 : // Override with a different module.
169 : test_dll_obj_path_ = testing::GetSrcRelativePath(
170 E : testing::kEmptyStringTableCoffName);
171 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
172 E : ImportedModule module("export_dll.dll");
173 : size_t function1 = module.AddSymbol(kFunction1Name,
174 E : ImportedModule::kAlwaysImport);
175 :
176 E : CoffAddImportsTransform transform;
177 E : transform.AddModule(&module);
178 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
179 E : &transform, &policy_, &block_graph_, headers_block_));
180 E : EXPECT_EQ(0u, transform.modules_added());
181 E : EXPECT_EQ(1u, transform.symbols_added());
182 :
183 E : EXPECT_TRUE(module.ModuleIsImported());
184 E : EXPECT_TRUE(module.SymbolIsImported(function1));
185 :
186 E : EXPECT_FALSE(module.ModuleWasAdded());
187 E : EXPECT_TRUE(module.SymbolWasAdded(function1));
188 E : }
189 :
190 :
191 : } // namespace transforms
192 : } // namespace pe
|